108 lines
2.3 KiB
Go
108 lines
2.3 KiB
Go
// Copyright 2023 the Pinniped contributors. All Rights Reserved.
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
|
|
package impersonator
|
|
|
|
import (
|
|
"errors"
|
|
"net/http"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/oauth2"
|
|
|
|
"go.pinniped.dev/internal/tokenclient"
|
|
)
|
|
|
|
func TestWrappedRoundTripper(t *testing.T) {
|
|
var base = new(oauth2.Transport)
|
|
|
|
roundTripper := authorizationRoundTripper{
|
|
base: base,
|
|
}
|
|
|
|
require.Equal(t, base, roundTripper.WrappedRoundTripper())
|
|
}
|
|
|
|
// fakeRoundTripper is a test double for http.RoundTripper. It remembers the
// most recent request handed to RoundTrip and answers with whatever canned
// response and error it was configured with.
type fakeRoundTripper struct {
	request  *http.Request  // last request passed to RoundTrip; nil until the first call
	response *http.Response // returned as-is from every RoundTrip call
	err      error          // returned as-is from every RoundTrip call
}

// RoundTrip records req and replies with the configured response and error.
func (f *fakeRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	f.request = req
	resp, err := f.response, f.err
	return resp, err
}

// Compile-time assertion that *fakeRoundTripper implements http.RoundTripper.
var _ http.RoundTripper = (*fakeRoundTripper)(nil)
|
|
|
|
// fakeCache is a stub token cache: Get always reports the token it was built
// with. An empty token stands in for "nothing cached".
type fakeCache struct {
	token string
}

// Get returns the stubbed token unchanged.
func (c *fakeCache) Get() string {
	cached := c.token
	return cached
}
|
|
|
|
// Compile-time assertion that *fakeCache implements
// tokenclient.ExpiringSingletonTokenCacheGet.
var (
	_ tokenclient.ExpiringSingletonTokenCacheGet = (*fakeCache)(nil)
)
|
|
|
|
func TestRoundTrip(t *testing.T) {
|
|
fakeResponse := new(http.Response)
|
|
for _, tt := range []struct {
|
|
name string
|
|
token string
|
|
baseResponse *http.Response
|
|
baseError string
|
|
wantResponse *http.Response
|
|
wantError string
|
|
}{
|
|
{
|
|
name: "happy path",
|
|
token: "token",
|
|
baseResponse: fakeResponse,
|
|
baseError: "error from base",
|
|
wantResponse: fakeResponse,
|
|
wantError: "error from base",
|
|
},
|
|
{
|
|
name: "no token available",
|
|
token: "", // since the cache always returns a non-pointer string, this indicates empty
|
|
wantError: "no token available",
|
|
},
|
|
} {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
base := &fakeRoundTripper{
|
|
response: new(http.Response),
|
|
err: errors.New(tt.baseError),
|
|
}
|
|
|
|
cache := &fakeCache{
|
|
token: tt.token,
|
|
}
|
|
|
|
roundTripper := &authorizationRoundTripper{
|
|
cache: cache,
|
|
base: base,
|
|
}
|
|
|
|
//nolint:noctx // this is test code
|
|
request, err := http.NewRequest("GET", "https://example.com", nil)
|
|
require.NoError(t, err)
|
|
defer request.Body.Close()
|
|
|
|
response, err := roundTripper.RoundTrip(request)
|
|
require.Equal(t, tt.wantResponse, response)
|
|
require.ErrorContains(t, err, tt.wantError)
|
|
defer response.Body.Close()
|
|
|
|
if tt.token != "" {
|
|
require.Equal(t, "Bearer "+tt.token, base.request.Header.Get("Authorization"))
|
|
} else {
|
|
require.Empty(t, base.request)
|
|
}
|
|
})
|
|
}
|
|
}
|