Fix a race detector error in a unit test

Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
Monis Khan 2021-03-10 11:24:42 -08:00 committed by Ryan Richard
parent 0b300cbe42
commit 6582c23edb
2 changed files with 67 additions and 52 deletions

View File

@ -357,7 +357,7 @@ func (c *impersonatorConfigController) ensureImpersonatorIsStarted(syncCtx contr
// The server has stopped, so finish shutting it down. // The server has stopped, so finish shutting it down.
// If that fails too, return both errors for logging purposes. // If that fails too, return both errors for logging purposes.
// By returning an error, the sync function will be called again // By returning an error, the sync function will be called again
// and we'll have a change to restart the server. // and we'll have a chance to restart the server.
close(c.errorCh) // We don't want ensureImpersonatorIsStopped to block on reading this channel. close(c.errorCh) // We don't want ensureImpersonatorIsStopped to block on reading this channel.
stoppingErr := c.ensureImpersonatorIsStopped(false) stoppingErr := c.ensureImpersonatorIsStopped(false)
return errors.NewAggregate([]error{runningErr, stoppingErr}) return errors.NewAggregate([]error{runningErr, stoppingErr})

View File

@ -289,14 +289,15 @@ func TestImpersonatorConfigControllerSync(t *testing.T) {
var cancelContext context.Context var cancelContext context.Context
var cancelContextCancelFunc context.CancelFunc var cancelContextCancelFunc context.CancelFunc
var syncContext *controllerlib.Context var syncContext *controllerlib.Context
var impersonatorFuncWasCalled int
var impersonatorFuncError error
var impersonatorFuncReturnedFuncError error
var startedTLSListener net.Listener
var frozenNow time.Time var frozenNow time.Time
var signingCertProvider dynamiccert.Provider var signingCertProvider dynamiccert.Provider
var signingCACertPEM, signingCAKeyPEM []byte var signingCACertPEM, signingCAKeyPEM []byte
var signingCASecret *corev1.Secret var signingCASecret *corev1.Secret
var impersonatorFuncWasCalled int
var impersonatorFuncError error
var impersonatorFuncReturnedFuncError error
var startedTLSListener net.Listener
var startedTLSListenerMutex sync.RWMutex
var testHTTPServer *http.Server var testHTTPServer *http.Server
var testHTTPServerMutex sync.RWMutex var testHTTPServerMutex sync.RWMutex
var testHTTPServerInterruptCh chan struct{} var testHTTPServerInterruptCh chan struct{}
@ -317,6 +318,48 @@ func TestImpersonatorConfigControllerSync(t *testing.T) {
return nil, impersonatorFuncError return nil, impersonatorFuncError
} }
startedTLSListenerMutex.Lock() // this is to satisfy the race detector
defer startedTLSListenerMutex.Unlock()
var err error
// Bind a listener to the port. Automatically choose the port for unit tests instead of using the real port.
startedTLSListener, err = tls.Listen("tcp", localhostIP+":0", &tls.Config{
MinVersion: tls.VersionTLS12,
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
certPEM, keyPEM := dynamicCertProvider.CurrentCertKeyContent()
if certPEM != nil && keyPEM != nil {
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
r.NoError(err)
return &tlsCert, nil
}
return nil, nil // no cached TLS certs
},
ClientAuth: tls.RequestClientCert,
VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
// Docs say that this will always be called in tls.RequestClientCert mode
// and that the second parameter will always be nil in that case.
// rawCerts will be raw ASN.1 certificates provided by the peer.
if len(rawCerts) != 1 {
return fmt.Errorf("expected to get one client cert on incoming request to test server")
}
clientCert := rawCerts[0]
currentClientCertCA := impersonationProxySignerCAProvider.CurrentCABundleContent()
if currentClientCertCA == nil {
return fmt.Errorf("impersonationProxySignerCAProvider does not have a current CA certificate")
}
// Assert that the client's cert was signed by the CA cert that the controller put into
// the CAContentProvider that was passed in.
parsed, err := x509.ParseCertificate(clientCert)
require.NoError(t, err)
roots := x509.NewCertPool()
require.True(t, roots.AppendCertsFromPEM(currentClientCertCA))
opts := x509.VerifyOptions{Roots: roots}
_, err = parsed.Verify(opts)
require.NoError(t, err)
return nil
},
})
r.NoError(err)
// Return a func that starts a fake server when called, and shuts down the fake server when stopCh is closed. // Return a func that starts a fake server when called, and shuts down the fake server when stopCh is closed.
// This fake server is enough like the real impersonation proxy server for this unit test because it // This fake server is enough like the real impersonation proxy server for this unit test because it
// uses the supplied providers to serve TLS. The goal of this unit test is to make sure that the server // uses the supplied providers to serve TLS. The goal of this unit test is to make sure that the server
@ -326,47 +369,6 @@ func TestImpersonatorConfigControllerSync(t *testing.T) {
return impersonatorFuncReturnedFuncError return impersonatorFuncReturnedFuncError
} }
var err error
// automatically choose the port for unit tests
startedTLSListener, err = tls.Listen("tcp", localhostIP+":0", &tls.Config{
MinVersion: tls.VersionTLS12,
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
certPEM, keyPEM := dynamicCertProvider.CurrentCertKeyContent()
if certPEM != nil && keyPEM != nil {
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
r.NoError(err)
return &tlsCert, nil
}
return nil, nil // no cached TLS certs
},
ClientAuth: tls.RequestClientCert,
VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
// Docs say that this will always be called in tls.RequestClientCert mode
// and that the second parameter will always be nil in that case.
// rawCerts will be raw ASN.1 certificates provided by the peer.
if rawCerts == nil || len(rawCerts) != 1 {
return fmt.Errorf("expected to get one client cert on incoming request to test server")
}
clientCert := rawCerts[0]
currentClientCertCA := impersonationProxySignerCAProvider.CurrentCABundleContent()
if currentClientCertCA == nil {
return fmt.Errorf("impersonationProxySignerCAProvider does not have a current CA certificate")
}
// Assert that the client's cert was signed by the CA cert that the controller put into
// the CAContentProvider that was passed in.
parsed, err := x509.ParseCertificate(clientCert)
require.NoError(t, err)
t.Log("PARSED CLIENT CERT")
roots := x509.NewCertPool()
require.True(t, roots.AppendCertsFromPEM(currentClientCertCA))
opts := x509.VerifyOptions{Roots: roots}
_, err = parsed.Verify(opts)
require.NoError(t, err)
return nil
},
})
r.NoError(err)
testHTTPServerMutex.Lock() // this is to satisfy the race detector testHTTPServerMutex.Lock() // this is to satisfy the race detector
testHTTPServer = &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { testHTTPServer = &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
_, err := fmt.Fprint(w, fakeServerResponseBody) _, err := fmt.Fprint(w, fakeServerResponseBody)
@ -376,7 +378,10 @@ func TestImpersonatorConfigControllerSync(t *testing.T) {
// Start serving requests in the background. // Start serving requests in the background.
go func() { go func() {
err := testHTTPServer.Serve(startedTLSListener) startedTLSListenerMutex.RLock() // this is to satisfy the race detector
listener := startedTLSListener
startedTLSListenerMutex.RUnlock()
err := testHTTPServer.Serve(listener)
if !errors.Is(err, http.ErrServerClosed) { if !errors.Is(err, http.ErrServerClosed) {
t.Log("Got an unexpected error while starting the fake http server!") t.Log("Got an unexpected error while starting the fake http server!")
r.NoError(err) // causes the test to crash, which is good enough because this should never happen r.NoError(err) // causes the test to crash, which is good enough because this should never happen
@ -394,7 +399,7 @@ func TestImpersonatorConfigControllerSync(t *testing.T) {
<-testHTTPServerInterruptCh <-testHTTPServerInterruptCh
} }
err = testHTTPServer.Close() err := testHTTPServer.Close()
t.Log("Got an unexpected error while stopping the fake http server!") t.Log("Got an unexpected error while stopping the fake http server!")
r.NoError(err) // causes the test to crash, which is good enough because this should never happen r.NoError(err) // causes the test to crash, which is good enough because this should never happen
@ -403,11 +408,15 @@ func TestImpersonatorConfigControllerSync(t *testing.T) {
} }
var testServerAddr = func() string { var testServerAddr = func() string {
var listener net.Listener
require.Eventually(t, func() bool { require.Eventually(t, func() bool {
return startedTLSListener != nil startedTLSListenerMutex.RLock() // this is to satisfy the race detector
listener = startedTLSListener
defer startedTLSListenerMutex.RUnlock()
return listener != nil
}, 20*time.Second, 50*time.Millisecond, "TLS listener never became not nil") }, 20*time.Second, 50*time.Millisecond, "TLS listener never became not nil")
return startedTLSListener.Addr().String() return listener.Addr().String()
} }
var closeTestHTTPServer = func() { var closeTestHTTPServer = func() {
@ -1967,9 +1976,15 @@ func TestImpersonatorConfigControllerSync(t *testing.T) {
startInformersAndController() startInformersAndController()
// The failure happens in a background goroutine, so the first sync succeeds. // The failure happens in a background goroutine, so the first sync succeeds.
r.NoError(runControllerSync()) r.NoError(runControllerSync())
// Eventually the server is not really running, because the startup failed. // The imperonatorFunc was called to construct an impersonator.
r.Nil(startedTLSListener)
r.Equal(impersonatorFuncWasCalled, 1) r.Equal(impersonatorFuncWasCalled, 1)
// Without waiting too long because we don't want the test to be slow, check if it seems like the
// server never started.
r.Never(func() bool {
testHTTPServerMutex.RLock() // this is to satisfy the race detector
defer testHTTPServerMutex.RUnlock()
return testHTTPServer != nil
}, 2*time.Second, 50*time.Millisecond)
r.Len(kubeAPIClient.Actions(), 3) r.Len(kubeAPIClient.Actions(), 3)
requireNodesListed(kubeAPIClient.Actions()[0]) requireNodesListed(kubeAPIClient.Actions()[0])
requireLoadBalancerWasCreated(kubeAPIClient.Actions()[1]) requireLoadBalancerWasCreated(kubeAPIClient.Actions()[1])