// Copyright 2021 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 package testlib import ( "context" "errors" "fmt" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/kubernetes" "go.pinniped.dev/internal/constable" ) type ( // loopTestingT records the failures observed during an iteration of the RequireEventually() loop. loopTestingT []assertionFailure // assertionFailure is a single error observed during an iteration of the RequireEventually() loop. assertionFailure struct { format string args []interface{} } ) // loopTestingT implements require.TestingT: var _ require.TestingT = (*loopTestingT)(nil) // Errorf is called by the assert.Assertions methods to record an error. func (e *loopTestingT) Errorf(format string, args ...interface{}) { *e = append(*e, assertionFailure{format, args}) } const errLoopFailNow = constable.Error("failing test now") // FailNow is called by the require.Assertions methods to force the code to immediately halt. It panics with a // sentinel value that is recovered by recoverLoopFailNow(). func (e *loopTestingT) FailNow() { panic(errLoopFailNow) } // ignoreFailNowPanic catches the panic from FailNow() and ignores it (allowing the FailNow() call to halt the test // but let the retry loop continue. func recoverLoopFailNow() { switch p := recover(); p { case nil, errLoopFailNow: // Ignore nil (success) and our sentinel value. return default: // Re-panic on any other value. panic(p) } } func RequireEventuallyf( t *testing.T, f func(requireEventually *require.Assertions), waitFor time.Duration, tick time.Duration, msg string, args ...interface{}, ) { t.Helper() RequireEventually(t, f, waitFor, tick, fmt.Sprintf(msg, args...)) } // RequireEventually is similar to require.Eventually() except that it is thread safe and provides a richer way to // write per-iteration assertions. func RequireEventually( t *testing.T, f func(requireEventually *require.Assertions), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}, ) { t.Helper() // Set up some bookkeeping so we can fail with a nice message if necessary. var ( startTime = time.Now() attempts int mostRecentFailures loopTestingT ) // Run the check until it completes with no assertion failures. waitErr := wait.PollImmediate(tick, waitFor, func() (bool, error) { t.Helper() attempts++ // Reset the recorded failures on each iteration. mostRecentFailures = nil // Ignore any panics caused by FailNow() -- they will cause the f() to return immediately but any errors // they've logged should be in mostRecentFailures. defer recoverLoopFailNow() // Run the per-iteration check, recording any failed assertions into mostRecentFailures. f(require.New(&mostRecentFailures)) // We're only done iterating if no assertions have failed. return len(mostRecentFailures) == 0, nil }) // If things eventually completed with no failures/timeouts, we're done. if waitErr == nil { return } // Re-assert the most recent set of failures with a nice error log. duration := time.Since(startTime).Round(100 * time.Millisecond) t.Errorf("failed to complete even after %s (%d attempts): %v", duration, attempts, waitErr) for _, failure := range mostRecentFailures { t.Errorf(failure.format, failure.args...) } // Fail the test now with the provided message. require.NoError(t, waitErr, msgAndArgs...) } // RequireEventuallyWithoutError is similar to require.Eventually() except that it also allows the caller to // return an error from the condition function. If the condition function returns an error at any // point, the assertion will immediately fail. func RequireEventuallyWithoutError( t *testing.T, f func() (bool, error), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}, ) { t.Helper() require.NoError(t, wait.PollImmediate(tick, waitFor, f), msgAndArgs...) } // RequireNeverWithoutError is similar to require.Never() except that it also allows the caller to // return an error from the condition function. If the condition function returns an error at any // point, the assertion will immediately fail. func RequireNeverWithoutError( t *testing.T, f func() (bool, error), waitFor time.Duration, tick time.Duration, msgAndArgs ...interface{}, ) { t.Helper() err := wait.PollImmediate(tick, waitFor, f) isWaitTimeout := errors.Is(err, wait.ErrWaitTimeout) if err != nil && !isWaitTimeout { require.NoError(t, err, msgAndArgs...) // this will fail and throw the right error message } if err == nil { // This prints the same error message that require.Never would print in this case. require.Fail(t, "Condition satisfied", msgAndArgs...) } } // assertNoRestartsDuringTest allows a caller to assert that there were no restarts for a Pod in the // provided namespace with the provided labelSelector during the lifetime of a test. func assertNoRestartsDuringTest(t *testing.T, namespace, labelSelector string) { t.Helper() kubeClient := NewKubernetesClientset(t) ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() previousRestartCounts := getRestartCounts(ctx, t, kubeClient, namespace, labelSelector) t.Cleanup(func() { ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) defer cancel() currentRestartCounts := getRestartCounts(ctx, t, kubeClient, namespace, labelSelector) for key, previousRestartCount := range previousRestartCounts { currentRestartCount, ok := currentRestartCounts[key] // If the container no longer exists, that's a test failure. if !assert.Truef( t, ok, "container %s existed at beginning of the test, but not the end", key.String(), ) { continue } // Expect the restart count to be the same as it was before the test. assert.Equal( t, previousRestartCount, currentRestartCount, "container %s has restarted %d times (original count was %d)", key.String(), currentRestartCount, previousRestartCount, ) } }) } type containerRestartKey struct { namespace string pod string container string } func (k containerRestartKey) String() string { return fmt.Sprintf("%s/%s/%s", k.namespace, k.pod, k.container) } type containerRestartMap map[containerRestartKey]int32 func getRestartCounts(ctx context.Context, t *testing.T, kubeClient kubernetes.Interface, namespace, labelSelector string) containerRestartMap { t.Helper() pods, err := kubeClient.CoreV1().Pods(namespace).List(ctx, metav1.ListOptions{LabelSelector: labelSelector}) require.NoError(t, err) restartCounts := make(containerRestartMap) for _, pod := range pods.Items { for _, container := range pod.Status.ContainerStatuses { key := containerRestartKey{ namespace: pod.Namespace, pod: pod.Name, container: container.Name, } restartCounts[key] = container.RestartCount } } return restartCounts }