diff --git a/internal/controller/authenticator/jwtcachefiller/jwtcachefiller_test.go b/internal/controller/authenticator/jwtcachefiller/jwtcachefiller_test.go index 3fdb7faf..6cc2bc76 100644 --- a/internal/controller/authenticator/jwtcachefiller/jwtcachefiller_test.go +++ b/internal/controller/authenticator/jwtcachefiller/jwtcachefiller_test.go @@ -486,7 +486,7 @@ func TestController(t *testing.T) { authenticated bool err error ) - _ = wait.PollImmediate(10*time.Millisecond, 5*time.Second, func() (bool, error) { + _ = wait.PollUntilContextTimeout(context.Background(), 10*time.Millisecond, 5*time.Second, true, func(ctx context.Context) (bool, error) { rsp, authenticated, err = cachedAuthenticator.AuthenticateToken(context.Background(), jwt) return !isNotInitialized(err), nil }) diff --git a/internal/dynamiccert/provider_test.go b/internal/dynamiccert/provider_test.go index df744fc0..9dc05b70 100644 --- a/internal/dynamiccert/provider_test.go +++ b/internal/dynamiccert/provider_test.go @@ -4,6 +4,7 @@ package dynamiccert import ( + "context" "crypto/tls" "crypto/x509" "net" @@ -194,7 +195,7 @@ func TestProviderWithDynamicServingCertificateController(t *testing.T) { var lastTLSConfig *tls.Config // it will take some time for the controller to catch up - err = wait.PollImmediate(time.Second, 30*time.Second, func() (bool, error) { + err = wait.PollUntilContextTimeout(context.Background(), time.Second, 30*time.Second, true, func(ctx context.Context) (bool, error) { actualTLSConfig, err := tlsConfig.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "force-standard-sni"}) if err != nil { return false, err diff --git a/test/testlib/assertions.go b/test/testlib/assertions.go index 877fdcfa..1178ed0e 100644 --- a/test/testlib/assertions.go +++ b/test/testlib/assertions.go @@ -1,11 +1,10 @@ -// Copyright 2021-2022 the Pinniped contributors. All Rights Reserved. +// Copyright 2021-2023 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 package testlib import ( "context" - "errors" "fmt" "testing" "time" @@ -88,7 +87,7 @@ func RequireEventually( ) // Run the check until it completes with no assertion failures. - waitErr := wait.PollImmediate(tick, waitFor, func() (bool, error) { + waitErr := wait.PollUntilContextTimeout(context.Background(), tick, waitFor, true, func(ctx context.Context) (bool, error) { t.Helper() attempts++ @@ -133,7 +132,10 @@ func RequireEventuallyWithoutError( msgAndArgs ...interface{}, ) { t.Helper() - require.NoError(t, wait.PollImmediate(tick, waitFor, f), msgAndArgs...) + // This previously used wait.PollImmediate (now deprecated), which did not take a ctx arg in the func. + // Hide this detail from the callers for now to keep the old signature. + fWithCtx := func(ctx context.Context) (bool, error) { return f() } + require.NoError(t, wait.PollUntilContextTimeout(context.Background(), tick, waitFor, true, fWithCtx), msgAndArgs...) } // RequireNeverWithoutError is similar to require.Never() except that it also allows the caller to @@ -147,9 +149,11 @@ func RequireNeverWithoutError( msgAndArgs ...interface{}, ) { t.Helper() - err := wait.PollImmediate(tick, waitFor, f) - isWaitTimeout := errors.Is(err, wait.ErrWaitTimeout) - if err != nil && !isWaitTimeout { + // This previously used wait.PollImmediate (now deprecated), which did not take a ctx arg in the func. + // Hide this detail from the callers for now to keep the old signature. + fWithCtx := func(ctx context.Context) (bool, error) { return f() } + err := wait.PollUntilContextTimeout(context.Background(), tick, waitFor, true, fWithCtx) + if err != nil && !wait.Interrupted(err) { require.NoError(t, err, msgAndArgs...) // this will fail and throw the right error message } if err == nil {