108 lines
2.3 KiB
Go

// Copyright 2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package impersonator
import (
"errors"
"net/http"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"go.pinniped.dev/internal/tokenclient"
)
// TestWrappedRoundTripper verifies that WrappedRoundTripper hands back the
// base round tripper that the authorizationRoundTripper was constructed with.
func TestWrappedRoundTripper(t *testing.T) {
	inner := &oauth2.Transport{}
	rt := authorizationRoundTripper{base: inner}

	got := rt.WrappedRoundTripper()

	require.Equal(t, inner, got)
}
// fakeRoundTripper is a test double for http.RoundTripper: it records the
// request passed to RoundTrip and replies with a canned response and error.
type fakeRoundTripper struct {
// request is the request most recently passed to RoundTrip (nil until called).
request *http.Request
// response is the canned response that RoundTrip returns.
response *http.Response
// err is the canned error that RoundTrip returns.
err error
}
// RoundTrip remembers the incoming request so the test can inspect it later,
// then returns the preconfigured response and error unchanged.
func (f *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	f.request = req
	return f.response, f.err
}
// Compile-time check that *fakeRoundTripper satisfies http.RoundTripper.
var _ http.RoundTripper = (*fakeRoundTripper)(nil)
// fakeCache is a test double for the token cache: it always yields a fixed token.
type fakeCache struct {
// token is the value returned by Get; the tests use "" to mean no token is available.
token string
}
// Get returns the fixed token configured on the fake cache.
func (fc *fakeCache) Get() string {
	return fc.token
}
// Compile-time check that *fakeCache satisfies tokenclient.ExpiringSingletonTokenCacheGet.
var _ tokenclient.ExpiringSingletonTokenCacheGet = (*fakeCache)(nil)
// TestRoundTrip exercises authorizationRoundTripper.RoundTrip with a fake token
// cache and a fake base round tripper.
//
// Each table entry configures the token served by the cache and the response
// and error that the base round tripper replies with, and states the response
// and error expected back from the round tripper under test. When a token is
// configured, the request seen by the base must carry it as a Bearer
// Authorization header; when the token is empty, the base must not be called.
func TestRoundTrip(t *testing.T) {
	fakeResponse := new(http.Response)

	for _, tt := range []struct {
		name         string
		token        string
		baseResponse *http.Response
		baseError    string
		wantResponse *http.Response
		wantError    string
	}{
		{
			name:         "happy path",
			token:        "token",
			baseResponse: fakeResponse,
			baseError:    "error from base",
			wantResponse: fakeResponse,
			wantError:    "error from base",
		},
		{
			name:      "no token available",
			token:     "", // since the cache always returns a non-pointer string, this indicates empty
			wantError: "no token available",
		},
	} {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			// Wire the table's canned response into the fake so that
			// baseResponse actually drives what the base returns.
			base := &fakeRoundTripper{response: tt.baseResponse}
			// Only configure an error when the test case asks for one;
			// errors.New("") would be a non-nil error with an empty message.
			if tt.baseError != "" {
				base.err = errors.New(tt.baseError)
			}

			cache := &fakeCache{token: tt.token}

			roundTripper := &authorizationRoundTripper{
				cache: cache,
				base:  base,
			}

			// The request is built with a nil body, so request.Body is nil and
			// there is nothing to close on the request side.
			//nolint:noctx // this is test code
			request, err := http.NewRequest(http.MethodGet, "https://example.com", nil)
			require.NoError(t, err)

			response, err := roundTripper.RoundTrip(request)
			// The response may be nil (no token) or have a nil Body (fake
			// response), so guard the close instead of dereferencing blindly.
			if response != nil && response.Body != nil {
				defer response.Body.Close()
			}

			require.Equal(t, tt.wantResponse, response)
			if tt.wantError != "" {
				require.ErrorContains(t, err, tt.wantError)
			} else {
				require.NoError(t, err)
			}

			if tt.token != "" {
				// The base must have been called with the Authorization header set.
				require.NotNil(t, base.request)
				require.Equal(t, "Bearer "+tt.token, base.request.Header.Get("Authorization"))
			} else {
				// Without a token the base round tripper must never be reached.
				require.Empty(t, base.request)
			}
		})
	}
}