I have a simple negative test that starts a server without security and sends a request to it. It then stops the server and restarts it with SSL. When the same HTTP client sends the request again, it gets a successful response (thereby failing the test case). It even gets a successful response when the server has been stopped.
After further investigation, I found that if the response to the first request is NOT read using ioutil.ReadAll, an error is returned for the second request. It seems that calling ioutil.ReadAll keeps the connection alive. I have looked at the source code but could not find the root cause.
Any idea what might be happening?
My test is as follows:
package main
import (
"fmt"
"testing"
)
// TestGet is a negative test: a client without TLS configuration that talked
// to a plain-HTTP server on a port must NOT get a successful response from
// that port once the server has been restarted with TLS.
//
// The client reuses pooled keep-alive connections once a response body has
// been read to EOF and closed (readResponse does both). If the old server keeps
// serving such an already-accepted connection after its listener is closed,
// the second request travels over that plain-HTTP connection and succeeds,
// which is exactly the regression this test guards against.
func TestGet(t *testing.T) {
	port = 8839
	url := "http://localhost:8839/hello"

	// Start the server without security and talk to it with a non-TLS client.
	startServer(false)
	netClient := getClient()
	fmt.Println("Sending request to ", url)
	response, err := netClient.Get(url)
	if err != nil {
		t.Fatalf("Error sending request: %s", err)
	}
	fmt.Println("Response: ", response)
	if response.StatusCode/100 != 2 {
		t.Fatalf("unexpected status from plain server: %s", response.Status)
	}
	// Draining and closing the body hands the connection back to the client's
	// idle pool, so the next request may try to reuse it.
	if err := readResponse(response); err != nil {
		t.Fatalf("reading first response: %s", err)
	}
	stopServer()

	// Restart on the same port with security.
	startServer(true)
	defer stopServer()

	// Use the same client to send the request to the secure server.
	fmt.Println("Sending request to ", url)
	response, err = netClient.Get(url)
	if err != nil {
		fmt.Println("Expected error: ", err)
		return
	}
	fmt.Println("Response: ", response)
	// A TLS server may reject a plaintext request with an HTTP error status
	// (e.g. 400) rather than a transport error, so receiving a response is not
	// by itself a failure; only a 2xx means the non-TLS client got through.
	succeeded := response.StatusCode/100 == 2
	readResponse(response)
	if succeeded {
		t.Error("Expected failure sending a request to a SSL server from a non-SSL client")
	}
}
The code is:
package main
import (
"crypto/tls"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"strconv"
"time"
)
// Package-level server state for this sample. listener holds the socket the
// current server accepts on (nil until a server has been started), and port is
// the TCP port the server binds to.
var (
	listener net.Listener
	port     int
)
func startHTTPServer(secure bool) {
mux := http.NewServeMux()
mux.Handle("/hello", &handler{})
addr := net.JoinHostPort("0.0.0.0", strconv.Itoa(port))
if secure {
config, err := getTLSConfig("server.pem", "server.key")
if err != nil {
log.Fatal("tls config: ", err)
}
listener, err = tls.Listen("tcp", addr, config)
if err != nil {
log.Fatal("listen: ", err)
}
} else {
listener, _ = net.Listen("tcp", addr)
}
http.Serve(listener, mux)
}
// getTLSConfig loads the PEM-encoded certificate from certfile and the matching
// private key from keyfile. It returns a server TLS config that presents that
// certificate, does not request client certificates, and accepts TLS 1.2 only.
// The returned error wraps the tls.LoadX509KeyPair error and names both files.
func getTLSConfig(certfile, keyfile string) (*tls.Config, error) {
	cert, err := tls.LoadX509KeyPair(certfile, keyfile)
	if err != nil {
		return nil, fmt.Errorf("load key pair (%s, %s): %w", certfile, keyfile, err)
	}
	config := &tls.Config{
		Certificates: []tls.Certificate{cert},
		ClientAuth:   tls.NoClientCert,
		MinVersion:   tls.VersionTLS12,
		MaxVersion:   tls.VersionTLS12,
	}
	return config, nil
}
func stopServer() {
fmt.Println("Stopping server")
if listener != nil {
listener.Close()
}
fmt.Println("Sleeping 5 seconds for server to stop")
time.Sleep(time.Second * 5)
}
// startServer launches startHTTPServer(secure) on a background goroutine and
// then sleeps five seconds to give it time to start listening.
// NOTE(review): the sleep is the only synchronization with that goroutine,
// including for its write to the package-level listener; a ready channel
// would be more reliable and race-free.
func startServer(secure bool) {
	scheme := "http"
	if secure {
		scheme = "https"
	}
	fmt.Println("Starting " + scheme + " server")

	go startHTTPServer(secure)

	fmt.Println("Sleeping 5 seconds for server to start")
	time.Sleep(5 * time.Second)
}
// handler answers every request routed to it with a fixed plain-text line.
type handler struct{}

// ServeHTTP sets a text/plain content type and writes the greeting. The
// request is not inspected, so any method, path, or body gets the same reply.
func (h *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	const greeting = "This is an example server.\n"
	headers := w.Header()
	headers.Set("Content-Type", "text/plain")
	w.Write([]byte(greeting))
}
func getClient() *http.Client {
tr := new(http.Transport)
//rootCAPool := x509.NewCertPool()
//cacert, err := ioutil.ReadFile("root.pem")
//if err != nil {
// fmt.Printf("Failed to read '%s': %s\n", cacert, err)
//}
//ok := rootCAPool.AppendCertsFromPEM(cacert)
//if !ok {
// fmt.Printf("Failed to process certificate from file %s\n", cacert)
//}
//tr.TLSClientConfig = &tls.Config{
// RootCAs: rootCAPool,
//}
var netClient = &http.Client{Transport: tr, Timeout: time.Second * 10}
return netClient
}
// getRequest builds a body-less POST request (despite the name) for url. On
// failure, such as an unparsable URL, it returns a nil request together with
// the error from http.NewRequest.
func getRequest(url string) (*http.Request, error) {
	return http.NewRequest("POST", url, nil)
}
// readResponse reads resp.Body to EOF, prints it as text, and closes it. A nil
// response or nil body is a no-op. It returns the read error, if any; a close
// error is only printed.
//
// Reading to EOF and closing lets the client's Transport put the connection
// back into its keep-alive pool, so a later request may reuse it.
func readResponse(resp *http.Response) error {
	if resp == nil || resp.Body == nil {
		return nil
	}
	// Close even when the read fails, so the connection is released.
	defer func() {
		if err := resp.Body.Close(); err != nil {
			fmt.Printf("Failed to close the response body: %s\n", err)
		}
	}()

	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	// Convert to string; printing the []byte directly shows decimal byte values.
	fmt.Println("Response: ", string(respBody))
	return nil
}
// main runs the same scenario as the test, outside the testing framework. A
// non-TLS client sends a request to a plain-HTTP server. The server is then
// stopped and restarted with TLS on the same port, and the identical request
// is sent again with the same client. Any error ends the process with status 1.
func main() {
	port = 8839
	startServer(false)
	url := "http://localhost:8839/hello"
	netClient := getClient()
	fmt.Println("Sending request to ", url)
	// getRequest builds the request with a nil body, so the same *http.Request
	// can be sent twice.
	req, err := getRequest(url)
	exitOnError(err)
	sendAndPrint(netClient, req)

	stopServer()
	startServer(true)
	fmt.Println("Sending request to ", url)
	sendAndPrint(netClient, req)
}

// sendAndPrint sends req with client, prints the status line and the body,
// and exits the process if sending or reading fails.
func sendAndPrint(client *http.Client, req *http.Request) {
	resp, err := client.Do(req)
	exitOnError(err)
	// A TLS server may reject a plaintext request with an HTTP error status
	// rather than a transport error, so show the status as well as the body.
	fmt.Println("Status: ", resp.Status)
	exitOnError(readResponse(resp))
}

// exitOnError prints err and exits with status 1 when err is non-nil.
func exitOnError(err error) {
	if err != nil {
		fmt.Println("Error was received: ", err)
		os.Exit(1)
	}
}