// Copyright 2021-2023 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 package dynamiccert import ( "context" "crypto/tls" "crypto/x509" "net" "reflect" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/apiserver/pkg/server/dynamiccertificates" "k8s.io/apiserver/pkg/storage/names" "go.pinniped.dev/internal/certauthority" "go.pinniped.dev/internal/crypto/ptls" "go.pinniped.dev/test/testlib" ) func TestProviderWithDynamicServingCertificateController(t *testing.T) { t.Parallel() tests := []struct { name string f func(t *testing.T, ca Provider, certKey Private) (wantClientCASubjects [][]byte, wantCerts []tls.Certificate) }{ { name: "no-op leave everything alone", f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { pool := x509.NewCertPool() ok := pool.AppendCertsFromPEM(ca.CurrentCABundleContent()) require.True(t, ok, "should have valid non-empty CA bundle") certPEM, keyPEM := certKey.CurrentCertKeyContent() cert, err := tls.X509KeyPair(certPEM, keyPEM) require.NoError(t, err) //nolint:staticcheck // since we're not using .Subjects() to access the system pool return pool.Subjects(), []tls.Certificate{cert} }, }, { name: "unset the CA", f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { ca.UnsetCertKeyContent() certPEM, keyPEM := certKey.CurrentCertKeyContent() cert, err := tls.X509KeyPair(certPEM, keyPEM) require.NoError(t, err) return nil, []tls.Certificate{cert} }, }, { name: "unset the serving cert - still serves the old content", f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { pool := x509.NewCertPool() ok := pool.AppendCertsFromPEM(ca.CurrentCABundleContent()) require.True(t, ok, "should have valid non-empty CA bundle") certPEM, keyPEM := certKey.CurrentCertKeyContent() cert, err := tls.X509KeyPair(certPEM, keyPEM) require.NoError(t, err) certKey.UnsetCertKeyContent() //nolint:staticcheck // since we're not using .Subjects() to access the system pool return pool.Subjects(), []tls.Certificate{cert} }, }, { name: "change to a new CA", f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { // use unique names for all CAs to make sure the pool subjects are different newCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour) require.NoError(t, err) caKey, err := newCA.PrivateKeyToPEM() require.NoError(t, err) err = ca.SetCertKeyContent(newCA.Bundle(), caKey) require.NoError(t, err) certPEM, keyPEM := certKey.CurrentCertKeyContent() cert, err := tls.X509KeyPair(certPEM, keyPEM) require.NoError(t, err) //nolint:staticcheck // since we're not using .Subjects() to access the system pool return newCA.Pool().Subjects(), []tls.Certificate{cert} }, }, { name: "change to new serving cert", f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { // use unique names for all CAs to make sure the pool subjects are different newCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour) require.NoError(t, err) certPEM, keyPEM, err := newCA.IssueServerCertPEM(nil, []net.IP{net.ParseIP("127.0.0.2")}, time.Hour) require.NoError(t, err) err = certKey.SetCertKeyContent(certPEM, keyPEM) require.NoError(t, err) cert, err := tls.X509KeyPair(certPEM, keyPEM) require.NoError(t, err) pool := x509.NewCertPool() ok := pool.AppendCertsFromPEM(ca.CurrentCABundleContent()) require.True(t, ok, "should have valid non-empty CA bundle") //nolint:staticcheck // since we're not using .Subjects() to access the system pool return pool.Subjects(), []tls.Certificate{cert} }, }, { name: "change both CA and serving cert", f: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) { // use unique names for all CAs to make sure the pool subjects are different newCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour) require.NoError(t, err) certPEM, keyPEM, err := newCA.IssueServerCertPEM(nil, []net.IP{net.ParseIP("127.0.0.3")}, time.Hour) require.NoError(t, err) err = certKey.SetCertKeyContent(certPEM, keyPEM) require.NoError(t, err) cert, err := tls.X509KeyPair(certPEM, keyPEM) require.NoError(t, err) // use unique names for all CAs to make sure the pool subjects are different newOtherCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-other-ca"), time.Hour) require.NoError(t, err) caKey, err := newOtherCA.PrivateKeyToPEM() require.NoError(t, err) err = ca.SetCertKeyContent(newOtherCA.Bundle(), caKey) require.NoError(t, err) //nolint:staticcheck // since we're not using .Subjects() to access the system pool return newOtherCA.Pool().Subjects(), []tls.Certificate{cert} }, }, } for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() // use unique names for all CAs to make sure the pool subjects are different ca, err := certauthority.New(names.SimpleNameGenerator.GenerateName("ca"), time.Hour) require.NoError(t, err) caKey, err := ca.PrivateKeyToPEM() require.NoError(t, err) caContent := NewCA("ca") err = caContent.SetCertKeyContent(ca.Bundle(), caKey) require.NoError(t, err) cert, key, err := ca.IssueServerCertPEM(nil, []net.IP{net.ParseIP("127.0.0.1")}, time.Hour) require.NoError(t, err) certKeyContent := NewServingCert("cert-key") err = certKeyContent.SetCertKeyContent(cert, key) require.NoError(t, err) tlsConfig := ptls.Default(nil) tlsConfig.ClientAuth = tls.RequestClientCert dynamicCertificateController := dynamiccertificates.NewDynamicServingCertificateController( tlsConfig, caContent, certKeyContent, nil, // we do not care about SNI nil, // we do not care about events ) caContent.AddListener(dynamicCertificateController) certKeyContent.AddListener(dynamicCertificateController) err = dynamicCertificateController.RunOnce() require.NoError(t, err) stopCh := make(chan struct{}) defer close(stopCh) go dynamicCertificateController.Run(1, stopCh) tlsConfig.GetConfigForClient = dynamicCertificateController.GetConfigForClient wantClientCASubjects, wantCerts := tt.f(t, caContent, certKeyContent) var lastTLSConfig *tls.Config // it will take some time for the controller to catch up err = wait.PollUntilContextTimeout(context.Background(), time.Second, 30*time.Second, true, func(ctx context.Context) (bool, error) { actualTLSConfig, err := tlsConfig.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "force-standard-sni"}) if err != nil { return false, err } lastTLSConfig = actualTLSConfig return reflect.DeepEqual(wantClientCASubjects, poolSubjects(actualTLSConfig.ClientCAs)) && reflect.DeepEqual(wantCerts, actualTLSConfig.Certificates), nil }) if err != nil && lastTLSConfig != nil { // for debugging failures t.Log("diff between client CAs:\n", cmp.Diff( testlib.Sdump(wantClientCASubjects), testlib.Sdump(poolSubjects(lastTLSConfig.ClientCAs)), )) t.Log("diff between serving certs:\n", cmp.Diff( testlib.Sdump(wantCerts), testlib.Sdump(lastTLSConfig.Certificates), )) } require.NoError(t, err) }) } } func poolSubjects(pool *x509.CertPool) [][]byte { if pool == nil { return nil } //nolint:staticcheck // since we're not using .Subjects() to access the system pool return pool.Subjects() } func TestNewServingCert(t *testing.T) { got := NewServingCert("") ok1 := assert.Implements(fakeT{}, (*Private)(nil), got) ok2 := assert.Implements(fakeT{}, (*Public)(nil), got) ok3 := assert.Implements(fakeT{}, (*Provider)(nil), got) require.True(t, ok1, "NewServingCert must implement Private") require.False(t, ok2, "NewServingCert must not implement Public") require.False(t, ok3, "NewServingCert must not implement Provider") } type fakeT struct{} func (fakeT) Errorf(string, ...interface{}) {}