ContainerImage.Pinniped/internal/testutil/tlsserver/tlsserver.go

125 lines
3.2 KiB
Go
Raw Normal View History

// Copyright 2021-2022 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package tlsserver
import (
"context"
"crypto/tls"
"encoding/pem"
"fmt"
"net"
"net/http"
"net/http/httptest"
"reflect"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"k8s.io/apimachinery/pkg/util/httpstream"
"go.pinniped.dev/internal/crypto/ptls"
)
// ctxKey is an unexported key type used by this package for the values it
// stores and looks up, so its keys cannot collide with keys of other types.
type ctxKey int
const (
// mapKey is the context key under which the per-connection *sync.Map is
// stored (see RecordTLSHello and getCtxMap). Values start at 1 (iota + 1).
mapKey ctxKey = iota + 1
// helloKey is the key inside that *sync.Map under which the recorded
// *tls.ClientHelloInfo is stored.
helloKey
)
// TLSTestServer starts an httptest TLS server that serves handler and returns it.
//
// The server's TLS configuration is set to ptls.Default(nil) before start-up.
// When f is non-nil it is called with the still-unstarted server, after the
// TLS configuration has been assigned, so that the caller can adjust it.
// The server's Close method is registered with t.Cleanup.
func TLSTestServer(t *testing.T, handler http.Handler, f func(*httptest.Server)) *httptest.Server {
	t.Helper()

	srv := httptest.NewUnstartedServer(handler)
	srv.TLS = ptls.Default(nil) // mimic API server config

	// Give the caller a chance to customize the server before it starts listening.
	if f != nil {
		f(srv)
	}

	srv.StartTLS()
	t.Cleanup(srv.Close)

	return srv
}
// TLSTestServerCA returns the certificate of the given (already started) TLS
// test server, PEM-encoded as a single "CERTIFICATE" block.
func TLSTestServerCA(server *httptest.Server) []byte {
	certBlock := pem.Block{
		Type:  "CERTIFICATE",
		Bytes: server.Certificate().Raw,
	}
	return pem.EncodeToMemory(&certBlock)
}
// RecordTLSHello configures server so that the TLS client hello of every
// connection is remembered and can later be inspected via the request context.
//
// It installs two hooks:
//   - server.Config.ConnContext attaches a fresh *sync.Map to each
//     connection's context under mapKey.
//   - server.TLS.GetConfigForClient stores the *tls.ClientHelloInfo in that
//     map under helloKey. It fails the handshake if the map is missing, or if
//     a hello was already stored for the connection and it is not deeply equal
//     to the new one.
//
// server.TLS must already be set (for example by TLSTestServer).
func RecordTLSHello(server *httptest.Server) {
	withHelloMap := func(ctx context.Context, _ net.Conn) context.Context {
		return context.WithValue(ctx, mapKey, &sync.Map{})
	}

	recordHello := func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
		store, found := getCtxMap(hello.Context())
		if !found {
			return nil, fmt.Errorf("could not find ctx map")
		}

		previous, alreadyStored := store.LoadOrStore(helloKey, hello)
		if alreadyStored && !reflect.DeepEqual(hello, previous) {
			return nil, fmt.Errorf("different client hello seen")
		}

		// A nil config tells crypto/tls to keep using the server's original config.
		return nil, nil
	}

	server.Config.ConnContext = withHelloMap
	server.TLS.GetConfigForClient = recordHello
}
// AssertTLS builds a TLS configuration by calling tlsConfigFunc(nil) and then
// delegates to AssertTLSConfig to compare it against the client hello that was
// recorded for the connection carrying r.
func AssertTLS(t *testing.T, r *http.Request, tlsConfigFunc ptls.ConfigFunc) {
	t.Helper()

	AssertTLSConfig(t, r, tlsConfigFunc(nil))
}
// AssertTLSConfig checks that the TLS client hello recorded for the connection
// that carried r (see RecordTLSHello) matches what tlsConfig is expected to
// produce.
//
// The expected values are derived as follows:
//   - supported versions: tlsConfig.MinVersion, preceded by the MinVersion of
//     ptls.Secure(nil) when the two differ;
//   - cipher suites: tlsConfig.CipherSuites, followed by the cipher suites of
//     ptls.Secure(nil) when the MinVersions differ;
//   - ALPN protocols: tlsConfig.NextProtos, without its first entry when r is
//     an upgrade request.
//
// The test is stopped via require if no client hello was recorded for the
// request. Mismatches in the hello itself are reported with assert plus a
// summarizing t.Errorf, so the calling http.Handler keeps running.
func AssertTLSConfig(t *testing.T, r *http.Request, tlsConfig *tls.Config) {
	t.Helper()

	m, ok := getCtxMap(r.Context())
	require.True(t, ok)

	h, ok := m.Load(helloKey)
	require.True(t, ok)

	info, ok := h.(*tls.ClientHelloInfo)
	require.True(t, ok)

	supportedVersions := []uint16{tlsConfig.MinVersion}

	// Cap the slice at its own length so that the append below always copies
	// instead of writing into spare capacity of the caller's backing array
	// (which would be a data race when handlers share one tlsConfig).
	ciphers := tlsConfig.CipherSuites[:len(tlsConfig.CipherSuites):len(tlsConfig.CipherSuites)]

	if secureTLSConfig := ptls.Secure(nil); tlsConfig.MinVersion != secureTLSConfig.MinVersion {
		supportedVersions = append([]uint16{secureTLSConfig.MinVersion}, supportedVersions...)
		ciphers = append(ciphers, secureTLSConfig.CipherSuites...)
	}

	isUpgrade := httpstream.IsUpgradeRequest(r)

	protos := tlsConfig.NextProtos
	// Upgrade requests are expected to omit the first configured protocol.
	// NOTE(review): presumably that entry is "h2" - confirm against ptls.
	// The length guard avoids a slice-bounds panic when NextProtos is empty.
	if isUpgrade && len(protos) > 0 {
		protos = protos[1:]
	}

	// use assert instead of require to not break the http.Handler with a panic
	ok1 := assert.Equal(t, supportedVersions, info.SupportedVersions)
	ok2 := assert.Equal(t, cipherSuiteIDsToStrings(ciphers), cipherSuiteIDsToStrings(info.CipherSuites))
	ok3 := assert.Equal(t, protos, info.SupportedProtos)

	if all := ok1 && ok2 && ok3; !all {
		t.Errorf("insecure TLS detected for %q %q %q upgrade=%v supportedVersions=%v ciphers=%v protos=%v",
			r.Proto, r.Method, r.URL.String(), isUpgrade, ok1, ok2, ok3)
	}
}
// cipherSuiteIDsToStrings maps each cipher suite ID to the name reported by
// tls.CipherSuiteName, preserving order. The result is always non-nil, even
// for nil or empty input.
func cipherSuiteIDsToStrings(ids []uint16) []string {
	names := make([]string, len(ids))
	for i := range ids {
		names[i] = tls.CipherSuiteName(ids[i])
	}
	return names
}
// getCtxMap looks up the value stored in ctx under mapKey and returns it as a
// *sync.Map. The boolean is false when no value is stored under that key or
// when the stored value is not a *sync.Map.
func getCtxMap(ctx context.Context) (*sync.Map, bool) {
	if store, isMap := ctx.Value(mapKey).(*sync.Map); isMap {
		return store, true
	}
	return nil, false
}