ContainerImage.Pinniped/internal/dynamiccert/provider_test.go

244 lines
7.7 KiB
Go
Raw Normal View History

// Copyright 2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package dynamiccert
import (
"crypto/tls"
"crypto/x509"
"net"
"reflect"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/apiserver/pkg/server/dynamiccertificates"
"k8s.io/apiserver/pkg/storage/names"
"go.pinniped.dev/internal/certauthority"
"go.pinniped.dev/test/testlib"
)
// TestProviderWithDynamicServingCertificateController plugs a CA provider and a
// serving-cert provider into a dynamiccertificates.DynamicServingCertificateController.
// Each case mutates the providers (or leaves them alone). The test then polls the
// controller's GetConfigForClient until the served client CA subjects and serving
// certificates match what the case expects, or a 30 second timeout is reached.
func TestProviderWithDynamicServingCertificateController(t *testing.T) {
	t.Parallel()

	// caBundleSubjects loads the provider's current CA bundle into a fresh pool
	// and returns that pool's subjects. It fails the test when nothing parses.
	caBundleSubjects := func(t *testing.T, ca Provider) [][]byte {
		t.Helper()
		pool := x509.NewCertPool()
		require.True(t, pool.AppendCertsFromPEM(ca.CurrentCABundleContent()), "should have valid non-empty CA bundle")
		return pool.Subjects()
	}

	// currentServingCert parses the provider's current cert/key PEM into a
	// tls.Certificate. It fails the test on a parse error.
	currentServingCert := func(t *testing.T, certKey Private) tls.Certificate {
		t.Helper()
		certPEM, keyPEM := certKey.CurrentCertKeyContent()
		cert, err := tls.X509KeyPair(certPEM, keyPEM)
		require.NoError(t, err)
		return cert
	}

	tests := []struct {
		name string
		// mutate changes (or leaves alone) the providers and reports what the
		// controller is expected to end up serving.
		mutate func(t *testing.T, ca Provider, certKey Private) (wantClientCASubjects [][]byte, wantCerts []tls.Certificate)
	}{
		{
			name: "no-op leave everything alone",
			mutate: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) {
				subjects := caBundleSubjects(t, ca)
				cert := currentServingCert(t, certKey)
				return subjects, []tls.Certificate{cert}
			},
		},
		{
			name: "unset the CA",
			mutate: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) {
				ca.UnsetCertKeyContent()
				cert := currentServingCert(t, certKey)
				// expect no client CA pool at all once the CA content is unset
				return nil, []tls.Certificate{cert}
			},
		},
		{
			name: "unset the serving cert - still serves the old content",
			mutate: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) {
				subjects := caBundleSubjects(t, ca)
				cert := currentServingCert(t, certKey)
				// capture the cert before unsetting: the old content is what should keep being served
				certKey.UnsetCertKeyContent()
				return subjects, []tls.Certificate{cert}
			},
		},
		{
			name: "change to a new CA",
			mutate: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) {
				// use unique names for all CAs to make sure the pool subjects are different
				replacementCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour)
				require.NoError(t, err)
				replacementKeyPEM, err := replacementCA.PrivateKeyToPEM()
				require.NoError(t, err)
				require.NoError(t, ca.SetCertKeyContent(replacementCA.Bundle(), replacementKeyPEM))

				cert := currentServingCert(t, certKey)
				return replacementCA.Pool().Subjects(), []tls.Certificate{cert}
			},
		},
		{
			name: "change to new serving cert",
			mutate: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) {
				// use unique names for all CAs to make sure the pool subjects are different
				issuer, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour)
				require.NoError(t, err)
				certPEM, keyPEM, err := issuer.IssueServerCertPEM(nil, []net.IP{net.ParseIP("127.0.0.2")}, time.Hour)
				require.NoError(t, err)
				require.NoError(t, certKey.SetCertKeyContent(certPEM, keyPEM))

				cert, err := tls.X509KeyPair(certPEM, keyPEM)
				require.NoError(t, err)

				// the CA provider is untouched, so its existing bundle is still expected
				return caBundleSubjects(t, ca), []tls.Certificate{cert}
			},
		},
		{
			name: "change both CA and serving cert",
			mutate: func(t *testing.T, ca Provider, certKey Private) ([][]byte, []tls.Certificate) {
				// use unique names for all CAs to make sure the pool subjects are different
				servingIssuer, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-ca"), time.Hour)
				require.NoError(t, err)
				certPEM, keyPEM, err := servingIssuer.IssueServerCertPEM(nil, []net.IP{net.ParseIP("127.0.0.3")}, time.Hour)
				require.NoError(t, err)
				require.NoError(t, certKey.SetCertKeyContent(certPEM, keyPEM))

				cert, err := tls.X509KeyPair(certPEM, keyPEM)
				require.NoError(t, err)

				// a second, unrelated CA (again uniquely named) replaces the client CA bundle
				clientCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("new-other-ca"), time.Hour)
				require.NoError(t, err)
				clientCAKeyPEM, err := clientCA.PrivateKeyToPEM()
				require.NoError(t, err)
				require.NoError(t, ca.SetCertKeyContent(clientCA.Bundle(), clientCAKeyPEM))

				return clientCA.Pool().Subjects(), []tls.Certificate{cert}
			},
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// use unique names for all CAs to make sure the pool subjects are different
			rootCA, err := certauthority.New(names.SimpleNameGenerator.GenerateName("ca"), time.Hour)
			require.NoError(t, err)
			rootCAKeyPEM, err := rootCA.PrivateKeyToPEM()
			require.NoError(t, err)

			caProvider := NewCA("ca")
			require.NoError(t, caProvider.SetCertKeyContent(rootCA.Bundle(), rootCAKeyPEM))

			servingCertPEM, servingKeyPEM, err := rootCA.IssueServerCertPEM(nil, []net.IP{net.ParseIP("127.0.0.1")}, time.Hour)
			require.NoError(t, err)

			servingProvider := NewServingCert("cert-key")
			require.NoError(t, servingProvider.SetCertKeyContent(servingCertPEM, servingKeyPEM))

			baseConfig := &tls.Config{
				MinVersion: tls.VersionTLS12,
				NextProtos: []string{"h2", "http/1.1"},
				ClientAuth: tls.RequestClientCert,
			}

			controller := dynamiccertificates.NewDynamicServingCertificateController(
				baseConfig,
				caProvider,
				servingProvider,
				nil, // SNI certificates are not exercised by this test
				nil, // no event recorder is needed
			)
			caProvider.AddListener(controller)
			servingProvider.AddListener(controller)

			// prime the controller synchronously, then keep it syncing in the background until the subtest ends
			require.NoError(t, controller.RunOnce())
			stopCh := make(chan struct{})
			defer close(stopCh)
			go controller.Run(1, stopCh)
			baseConfig.GetConfigForClient = controller.GetConfigForClient

			wantClientCASubjects, wantCerts := tt.mutate(t, caProvider, servingProvider)

			// the controller applies updates asynchronously, so poll until it has caught up
			var lastSeen *tls.Config
			pollErr := wait.PollImmediate(time.Second, 30*time.Second, func() (bool, error) {
				served, getErr := baseConfig.GetConfigForClient(&tls.ClientHelloInfo{ServerName: "force-standard-sni"})
				if getErr != nil {
					return false, getErr
				}
				lastSeen = served

				caMatch := reflect.DeepEqual(wantClientCASubjects, poolSubjects(served.ClientCAs))
				certMatch := reflect.DeepEqual(wantCerts, served.Certificates)
				return caMatch && certMatch, nil
			})

			if pollErr != nil && lastSeen != nil {
				// dump diffs of the last observed config to make failures debuggable
				t.Log("diff between client CAs:\n", cmp.Diff(
					testlib.Sdump(wantClientCASubjects),
					testlib.Sdump(poolSubjects(lastSeen.ClientCAs)),
				))
				t.Log("diff between serving certs:\n", cmp.Diff(
					testlib.Sdump(wantCerts),
					testlib.Sdump(lastSeen.Certificates),
				))
			}
			require.NoError(t, pollErr)
		})
	}
}
// poolSubjects returns the DER-encoded subjects of every certificate in pool.
// A nil pool yields a nil slice, so "no pool configured" can be compared
// against a nil expectation with reflect.DeepEqual.
func poolSubjects(pool *x509.CertPool) [][]byte {
	if pool != nil {
		return pool.Subjects()
	}
	return nil
}
// TestNewServingCert pins the interface surface of NewServingCert: the value it
// returns must satisfy Private and must satisfy neither Public nor Provider.
func TestNewServingCert(t *testing.T) {
	servingCert := NewServingCert("")

	checks := []struct {
		iface interface{}
		want  bool
		msg   string
	}{
		{iface: (*Private)(nil), want: true, msg: "NewServingCert must implement Private"},
		{iface: (*Public)(nil), want: false, msg: "NewServingCert must not implement Public"},
		{iface: (*Provider)(nil), want: false, msg: "NewServingCert must not implement Provider"},
	}

	for _, check := range checks {
		// fakeT discards assertion failures, so Implements only yields a bool here
		implemented := assert.Implements(fakeT{}, check.iface, servingCert)
		if check.want {
			require.True(t, implemented, check.msg)
		} else {
			require.False(t, implemented, check.msg)
		}
	}
}
// fakeT is passed as the failure reporter to assert.Implements in this file.
// Its Errorf throws away every report, so an assert helper called with it
// only produces a boolean result and never fails the surrounding test.
type fakeT struct{}

// Errorf ignores the format string and arguments it is given.
func (fakeT) Errorf(format string, args ...interface{}) {}