117 lines
2.5 KiB
Go
117 lines
2.5 KiB
Go
|
// Copyright 2021 the Pinniped contributors. All Rights Reserved.
|
||
|
// SPDX-License-Identifier: Apache-2.0
|
||
|
|
||
|
package phttp
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/tls"
|
||
|
"crypto/x509"
|
||
|
"net/http"
|
||
|
"testing"
|
||
|
|
||
|
"github.com/stretchr/testify/assert"
|
||
|
"github.com/stretchr/testify/require"
|
||
|
"k8s.io/apimachinery/pkg/util/net"
|
||
|
"k8s.io/client-go/util/cert"
|
||
|
|
||
|
"go.pinniped.dev/internal/crypto/ptls"
|
||
|
"go.pinniped.dev/internal/testutil/tlsserver"
|
||
|
)
|
||
|
|
||
|
// TestUnwrap ensures that the http.Client structs returned by this package contain
|
||
|
// a transport that can be fully unwrapped to get access to the underlying TLS config.
|
||
|
func TestUnwrap(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
f func(*x509.CertPool) *http.Client
|
||
|
}{
|
||
|
{
|
||
|
name: "default",
|
||
|
f: Default,
|
||
|
},
|
||
|
{
|
||
|
name: "secure",
|
||
|
f: Secure,
|
||
|
},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
tt := tt
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
p, err := x509.SystemCertPool()
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
c := tt.f(p)
|
||
|
|
||
|
tlsConfig, err := net.TLSClientConfig(c.Transport)
|
||
|
require.NoError(t, err)
|
||
|
require.NotNil(t, tlsConfig)
|
||
|
|
||
|
require.NotEmpty(t, tlsConfig.NextProtos)
|
||
|
require.GreaterOrEqual(t, tlsConfig.MinVersion, uint16(tls.VersionTLS12))
|
||
|
require.Equal(t, p, tlsConfig.RootCAs)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestClient(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
clientFunc func(*x509.CertPool) *http.Client
|
||
|
configFunc ptls.ConfigFunc
|
||
|
}{
|
||
|
{
|
||
|
name: "default",
|
||
|
clientFunc: Default,
|
||
|
configFunc: ptls.Default,
|
||
|
},
|
||
|
{
|
||
|
name: "secure",
|
||
|
clientFunc: Secure,
|
||
|
configFunc: ptls.Secure,
|
||
|
},
|
||
|
}
|
||
|
for _, tt := range tests {
|
||
|
tt := tt
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
t.Parallel()
|
||
|
|
||
|
var sawRequest bool
|
||
|
server := tlsserver.TLSTestServer(t, http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
||
|
tlsserver.AssertTLS(t, r, tt.configFunc)
|
||
|
assertUserAgent(t, r)
|
||
|
sawRequest = true
|
||
|
}), tlsserver.RecordTLSHello)
|
||
|
|
||
|
rootCAs, err := cert.NewPoolFromBytes(tlsserver.TLSTestServerCA(server))
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
c := tt.clientFunc(rootCAs)
|
||
|
|
||
|
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil)
|
||
|
require.NoError(t, err)
|
||
|
|
||
|
resp, err := c.Do(req)
|
||
|
require.NoError(t, err)
|
||
|
require.NoError(t, resp.Body.Close())
|
||
|
|
||
|
require.True(t, sawRequest)
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func assertUserAgent(t *testing.T, r *http.Request) {
|
||
|
t.Helper()
|
||
|
|
||
|
ua := r.Header.Get("user-agent")
|
||
|
|
||
|
// use assert instead of require to not break the http.Handler with a panic
|
||
|
assert.Contains(t, ua, ") kubernetes/")
|
||
|
}
|