Use an interface instead of a concrete type for UpstreamOIDCIdentityProvider

Because we want it to implement an AuthcodeExchanger interface and
do it in a way that will be more unit test-friendly than the underlying
library that we intend to use inside its implementation.
This commit is contained in:
Ryan Richard 2020-11-18 13:38:13 -08:00
parent 97552aec5f
commit 227fbd63aa
10 changed files with 220 additions and 65 deletions

View File

@ -62,7 +62,7 @@ const (
// IDPCache is a thread safe cache that holds a list of validated upstream OIDC IDP configurations. // IDPCache is a thread safe cache that holds a list of validated upstream OIDC IDP configurations.
type IDPCache interface { type IDPCache interface {
SetIDPList([]provider.UpstreamOIDCIdentityProvider) SetIDPList([]provider.UpstreamOIDCIdentityProviderI)
} }
// lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration. // lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration.
@ -132,13 +132,13 @@ func (c *controller) Sync(ctx controllerlib.Context) error {
} }
requeue := false requeue := false
validatedUpstreams := make([]provider.UpstreamOIDCIdentityProvider, 0, len(actualUpstreams)) validatedUpstreams := make([]provider.UpstreamOIDCIdentityProviderI, 0, len(actualUpstreams))
for _, upstream := range actualUpstreams { for _, upstream := range actualUpstreams {
valid := c.validateUpstream(ctx, upstream) valid := c.validateUpstream(ctx, upstream)
if valid == nil { if valid == nil {
requeue = true requeue = true
} else { } else {
validatedUpstreams = append(validatedUpstreams, *valid) validatedUpstreams = append(validatedUpstreams, provider.UpstreamOIDCIdentityProviderI(valid))
} }
} }
c.cache.SetIDPList(validatedUpstreams) c.cache.SetIDPList(validatedUpstreams)
@ -258,6 +258,8 @@ func (c *controller) validateIssuer(ctx context.Context, upstream *v1alpha1.Upst
c.validatorCache.putProvider(&upstream.Spec, discoveredProvider) c.validatorCache.putProvider(&upstream.Spec, discoveredProvider)
} }
// TODO also parse the token endpoint from the discovery info and put it onto the `result`
// Parse out and validate the discovered authorize endpoint. // Parse out and validate the discovered authorize endpoint.
authURL, err := url.Parse(discoveredProvider.Endpoint().AuthURL) authURL, err := url.Parse(discoveredProvider.Endpoint().AuthURL)
if err != nil { if err != nil {

View File

@ -527,7 +527,9 @@ func TestController(t *testing.T) {
kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0) kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0)
testLog := testlogger.New(t) testLog := testlogger.New(t)
cache := provider.NewDynamicUpstreamIDPProvider() cache := provider.NewDynamicUpstreamIDPProvider()
cache.SetIDPList([]provider.UpstreamOIDCIdentityProvider{{Name: "initial-entry"}}) initialProviderList := make([]provider.UpstreamOIDCIdentityProviderI, 1)
initialProviderList[0] = &provider.UpstreamOIDCIdentityProvider{Name: "initial-entry"}
cache.SetIDPList(initialProviderList)
controller := New( controller := New(
cache, cache,
@ -551,7 +553,13 @@ func TestController(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
require.Equal(t, strings.Join(tt.wantLogs, "\n"), strings.Join(testLog.Lines(), "\n")) require.Equal(t, strings.Join(tt.wantLogs, "\n"), strings.Join(testLog.Lines(), "\n"))
require.ElementsMatch(t, tt.wantResultingCache, cache.GetIDPList())
actualIDPList := cache.GetIDPList()
require.Equal(t, len(tt.wantResultingCache), len(actualIDPList))
for i := range actualIDPList {
actualIDP := actualIDPList[i].(*provider.UpstreamOIDCIdentityProvider)
require.Equal(t, tt.wantResultingCache[i], *actualIDP)
}
actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().UpstreamOIDCProviders(testNamespace).List(ctx, metav1.ListOptions{}) actualUpstreams, err := fakePinnipedClient.IDPV1alpha1().UpstreamOIDCProviders(testNamespace).List(ctx, metav1.ListOptions{})
require.NoError(t, err) require.NoError(t, err)

View File

@ -88,12 +88,12 @@ func NewHandler(
} }
upstreamOAuthConfig := oauth2.Config{ upstreamOAuthConfig := oauth2.Config{
ClientID: upstreamIDP.ClientID, ClientID: upstreamIDP.GetClientID(),
Endpoint: oauth2.Endpoint{ Endpoint: oauth2.Endpoint{
AuthURL: upstreamIDP.AuthorizationURL.String(), AuthURL: upstreamIDP.GetAuthorizationURL().String(),
}, },
RedirectURL: fmt.Sprintf("%s/callback/%s", issuer, upstreamIDP.Name), RedirectURL: fmt.Sprintf("%s/callback/%s", issuer, upstreamIDP.GetName()),
Scopes: upstreamIDP.Scopes, Scopes: upstreamIDP.GetScopes(),
} }
encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, upstreamStateEncoder) encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, upstreamStateEncoder)
@ -150,7 +150,7 @@ func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) {
} }
} }
func chooseUpstreamIDP(idpListGetter oidc.IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) { func chooseUpstreamIDP(idpListGetter oidc.IDPListGetter) (provider.UpstreamOIDCIdentityProviderI, error) {
allUpstreamIDPs := idpListGetter.GetIDPList() allUpstreamIDPs := idpListGetter.GetIDPList()
if len(allUpstreamIDPs) == 0 { if len(allUpstreamIDPs) == 0 {
return nil, httperr.New( return nil, httperr.New(
@ -163,7 +163,7 @@ func chooseUpstreamIDP(idpListGetter oidc.IDPListGetter) (*provider.UpstreamOIDC
"Too many upstream providers are configured (support for multiple upstreams is not yet implemented)", "Too many upstream providers are configured (support for multiple upstreams is not yet implemented)",
) )
} }
return &allUpstreamIDPs[0], nil return allUpstreamIDPs[0], nil
} }
func generateValues( func generateValues(

View File

@ -113,7 +113,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth") upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth")
require.NoError(t, err) require.NoError(t, err)
upstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{ upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{
Name: "some-idp", Name: "some-idp",
ClientID: "some-client-id", ClientID: "some-client-id",
AuthorizationURL: *upstreamAuthURL, AuthorizationURL: *upstreamAuthURL,
@ -122,7 +122,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
issuer := "https://my-issuer.com/some-path" issuer := "https://my-issuer.com/some-path"
// Configure fosite the same way that the production code would, except use in-memory storage. // Configure fosite the same way that the production code would, using NullStorage to turn off storage.
oauthStore := oidc.NullStorage{} oauthStore := oidc.NullStorage{}
hmacSecret := []byte("some secret - must have at least 32 bytes") hmacSecret := []byte("some secret - must have at least 32 bytes")
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
@ -771,13 +771,13 @@ func TestAuthorizationEndpoint(t *testing.T) {
runOneTestCase(t, test, subject) runOneTestCase(t, test, subject)
// Call the setter to change the upstream IDP settings. // Call the setter to change the upstream IDP settings.
newProviderSettings := provider.UpstreamOIDCIdentityProvider{ newProviderSettings := testutil.TestUpstreamOIDCIdentityProvider{
Name: "some-other-idp", Name: "some-other-idp",
ClientID: "some-other-client-id", ClientID: "some-other-client-id",
AuthorizationURL: *upstreamAuthURL, AuthorizationURL: *upstreamAuthURL,
Scopes: []string{"other-scope1", "other-scope2"}, Scopes: []string{"other-scope1", "other-scope2"},
} }
test.idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProvider{newProviderSettings}) test.idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProviderI{provider.UpstreamOIDCIdentityProviderI(&newProviderSettings)})
// Update the expectations of the test case to match the new upstream IDP settings. // Update the expectations of the test case to match the new upstream IDP settings.
test.wantLocationHeader = urlWithQuery(upstreamAuthURL.String(), test.wantLocationHeader = urlWithQuery(upstreamAuthURL.String(),

View File

@ -10,6 +10,8 @@ import (
"net/url" "net/url"
"path" "path"
"github.com/ory/fosite"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
@ -17,10 +19,7 @@ import (
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
) )
func NewHandler( func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provider, stateDecoder, cookieDecoder oidc.Decoder) http.Handler {
idpListGetter oidc.IDPListGetter,
stateDecoder, cookieDecoder oidc.Decoder,
) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
state, err := validateRequest(r, stateDecoder, cookieDecoder) state, err := validateRequest(r, stateDecoder, cookieDecoder)
if err != nil { if err != nil {
@ -84,10 +83,10 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder)
return state, nil return state, nil
} }
func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProvider { func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProviderI {
_, lastPathComponent := path.Split(r.URL.Path) _, lastPathComponent := path.Split(r.URL.Path)
for _, p := range idpListGetter.GetIDPList() { for _, p := range idpListGetter.GetIDPList() {
if p.Name == lastPathComponent { if p.GetName() == lastPathComponent {
return &p return &p
} }
} }

View File

@ -13,8 +13,11 @@ import (
"testing" "testing"
"github.com/gorilla/securecookie" "github.com/gorilla/securecookie"
"github.com/ory/fosite"
"github.com/ory/fosite/storage"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
) )
@ -28,23 +31,41 @@ func TestCallbackEndpoint(t *testing.T) {
downstreamRedirectURI = "http://127.0.0.1/callback" downstreamRedirectURI = "http://127.0.0.1/callback"
) )
upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth") // TODO use a fosite memory store and pass in a fostite oauthHelper
require.NoError(t, err) // TODO write a test double for UpstreamOIDCIdentityProviderI ID token with a claim called "the-user-claim" and put a username as the value of that claim
otherUpstreamAuthURL, err := url.Parse("https://some-other-upstream-idp:8443/auth") // TODO assert that after the callback request, the fosite storage has 1 authcode key saved,
require.NoError(t, err) // and it is the same key that was returned in the redirect,
// and the value in storage includes the username in the fosite session
// TODO do the same thing with the groups list (store it in the fosite session as JWT claim)
// TODO test for when UpstreamOIDCIdentityProviderI authcode exchange fails
// TODO wire in the callback endpoint into the oidc manager request router
// TODO update the upstream watcher controller to also populate the new fields
// TODO update the integration test
// TODO DO NOT store the upstream tokens (or maybe just the refresh token) for this story. In a future story, we can store them/it in some other storage interface indexed by the same authcode hash that fosite used for storage.
// TODO grab the upstream config name from the state param instead of the URL path
upstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{ // Configure fosite the same way that the production code would, except use in-memory storage.
Name: happyUpstreamIDPName, oauthStore := &storage.MemoryStore{
ClientID: "some-client-id", Clients: map[string]fosite.Client{oidc.PinnipedCLIOIDCClient().ID: oidc.PinnipedCLIOIDCClient()},
AuthorizationURL: *upstreamAuthURL, AuthorizeCodes: map[string]storage.StoreAuthorizeCode{},
Scopes: []string{"scope1", "scope2"}, PKCES: map[string]fosite.Requester{},
IDSessions: map[string]fosite.Requester{},
}
hmacSecret := []byte("some secret - must have at least 32 bytes")
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret)
upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{
Name: happyUpstreamIDPName,
ClientID: "some-client-id",
UsernameClaim: "the-user-claim",
Scopes: []string{"scope1", "scope2"},
} }
otherUpstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{ otherUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{
Name: "other-upstream-idp-name", Name: "other-upstream-idp-name",
ClientID: "other-some-client-id", ClientID: "other-some-client-id",
AuthorizationURL: *otherUpstreamAuthURL, Scopes: []string{"other-scope1", "other-scope2"},
Scopes: []string{"other-scope1", "other-scope2"},
} }
var stateEncoderHashKey = []byte("fake-hash-secret") var stateEncoderHashKey = []byte("fake-hash-secret")
@ -61,7 +82,7 @@ func TestCallbackEndpoint(t *testing.T) {
happyDownstreamState := "some-downstream-state" happyDownstreamState := "some-downstream-state"
happyOrignalRequestParams := url.Values{ happyOriginalRequestParams := url.Values{
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"openid profile email"}, "scope": []string{"openid profile email"},
"client_id": []string{"pinniped-cli"}, "client_id": []string{"pinniped-cli"},
@ -77,7 +98,7 @@ func TestCallbackEndpoint(t *testing.T) {
happyState, err := happyStateCodec.Encode("s", happyState, err := happyStateCodec.Encode("s",
testutil.ExpectedUpstreamStateParamFormat{ testutil.ExpectedUpstreamStateParamFormat{
P: happyOrignalRequestParams, P: happyOriginalRequestParams,
N: happyNonce, N: happyNonce,
C: happyCSRF, C: happyCSRF,
K: happyPKCE, K: happyPKCE,
@ -88,7 +109,7 @@ func TestCallbackEndpoint(t *testing.T) {
wrongCSRFValueState, err := happyStateCodec.Encode("s", wrongCSRFValueState, err := happyStateCodec.Encode("s",
testutil.ExpectedUpstreamStateParamFormat{ testutil.ExpectedUpstreamStateParamFormat{
P: happyOrignalRequestParams, P: happyOriginalRequestParams,
N: happyNonce, N: happyNonce,
C: "wrong-csrf-value", C: "wrong-csrf-value",
K: happyPKCE, K: happyPKCE,
@ -99,7 +120,7 @@ func TestCallbackEndpoint(t *testing.T) {
wrongVersionState, err := happyStateCodec.Encode("s", wrongVersionState, err := happyStateCodec.Encode("s",
testutil.ExpectedUpstreamStateParamFormat{ testutil.ExpectedUpstreamStateParamFormat{
P: happyOrignalRequestParams, P: happyOriginalRequestParams,
N: happyNonce, N: happyNonce,
C: happyCSRF, C: happyCSRF,
K: happyPKCE, K: happyPKCE,
@ -260,7 +281,7 @@ func TestCallbackEndpoint(t *testing.T) {
for _, test := range tests { for _, test := range tests {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
subject := NewHandler(test.idpListGetter, happyStateCodec, happyCookieCodec) subject := NewHandler(test.idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec)
req := httptest.NewRequest(test.method, test.path, nil) req := httptest.NewRequest(test.method, test.path, nil)
if test.csrfCookie != "" { if test.csrfCookie != "" {
req.Header.Set("Cookie", test.csrfCookie) req.Header.Set("Cookie", test.csrfCookie)
@ -285,7 +306,7 @@ func TestCallbackEndpoint(t *testing.T) {
capturedAuthCode := submatches[1] capturedAuthCode := submatches[1]
_ = capturedAuthCode _ = capturedAuthCode
// Assert capturedAuthCode storage stuff... // TODO Assert capturedAuthCode storage stuff...
// Assert that body contains anchor tag with redirect location. // Assert that body contains anchor tag with redirect location.
anchorTagWithLocationHref := fmt.Sprintf("<a href=\"%s\">Found</a>.\n\n", html.EscapeString(actualLocation)) anchorTagWithLocationHref := fmt.Sprintf("<a href=\"%s\">Found</a>.\n\n", html.EscapeString(actualLocation))

View File

@ -106,5 +106,5 @@ func FositeOauth2Helper(oauthStore interface{}, hmacSecretOfLengthAtLeast32 []by
} }
type IDPListGetter interface { type IDPListGetter interface {
GetIDPList() []provider.UpstreamOIDCIdentityProvider GetIDPList() []provider.UpstreamOIDCIdentityProviderI
} }

View File

@ -4,48 +4,114 @@
package provider package provider
import ( import (
"context"
"net/url" "net/url"
"sync" "sync"
"go.pinniped.dev/internal/oidcclient"
"go.pinniped.dev/internal/oidcclient/nonce"
"go.pinniped.dev/internal/oidcclient/pkce"
) )
type UpstreamOIDCIdentityProvider struct { type UpstreamOIDCIdentityProviderI interface {
// A name for this upstream provider, which will be used as a component of the path for the callback endpoint // A name for this upstream provider, which will be used as a component of the path for the callback endpoint
// hosted by the Supervisor. // hosted by the Supervisor.
Name string GetName() string
// The Oauth client ID registered with the upstream provider to be used in the authorization flow. // The Oauth client ID registered with the upstream provider to be used in the authorization code flow.
ClientID string GetClientID() string
// The Authorization Endpoint fetched from discovery. // The Authorization Endpoint fetched from discovery.
AuthorizationURL url.URL GetAuthorizationURL() *url.URL
// Scopes to request in authorization flow. // Scopes to request in authorization flow.
Scopes []string GetScopes() []string
// ID Token username claim name. May return empty string, in which case we will use some reasonable defaults.
GetUsernameClaim() string
// ID Token groups claim name. May return empty string, in which case we won't try to read groups from the upstream provider.
GetGroupsClaim() string
AuthcodeExchanger
}
// Performs upstream OIDC authorization code exchange and token validation.
// Returns the validated raw tokens as well as the parsed claims of the ID token.
type AuthcodeExchanger interface {
ExchangeAuthcodeAndValidateTokens(
ctx context.Context,
authcode string,
pkceCodeVerifier pkce.Code,
expectedIDTokenNonce nonce.Nonce,
) (tokens oidcclient.Token, parsedIDTokenClaims map[string]interface{}, err error)
}
type UpstreamOIDCIdentityProvider struct {
Name string
ClientID string
AuthorizationURL url.URL
UsernameClaim string
GroupsClaim string
Scopes []string
}
func (u *UpstreamOIDCIdentityProvider) GetName() string {
return u.Name
}
func (u *UpstreamOIDCIdentityProvider) GetClientID() string {
return u.ClientID
}
func (u *UpstreamOIDCIdentityProvider) GetAuthorizationURL() *url.URL {
return &u.AuthorizationURL
}
func (u *UpstreamOIDCIdentityProvider) GetScopes() []string {
return u.Scopes
}
func (u *UpstreamOIDCIdentityProvider) GetUsernameClaim() string {
return u.UsernameClaim
}
func (u *UpstreamOIDCIdentityProvider) GetGroupsClaim() string {
return u.GroupsClaim
}
func (u *UpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens(
ctx context.Context,
authcode string,
pkceCodeVerifier pkce.Code,
expectedIDTokenNonce nonce.Nonce,
) (oidcclient.Token, map[string]interface{}, error) {
panic("TODO implement me") // TODO
} }
type DynamicUpstreamIDPProvider interface { type DynamicUpstreamIDPProvider interface {
SetIDPList(oidcIDPs []UpstreamOIDCIdentityProvider) SetIDPList(oidcIDPs []UpstreamOIDCIdentityProviderI)
GetIDPList() []UpstreamOIDCIdentityProvider GetIDPList() []UpstreamOIDCIdentityProviderI
} }
type dynamicUpstreamIDPProvider struct { type dynamicUpstreamIDPProvider struct {
oidcProviders []UpstreamOIDCIdentityProvider oidcProviders []UpstreamOIDCIdentityProviderI
mutex sync.RWMutex mutex sync.RWMutex
} }
func NewDynamicUpstreamIDPProvider() DynamicUpstreamIDPProvider { func NewDynamicUpstreamIDPProvider() DynamicUpstreamIDPProvider {
return &dynamicUpstreamIDPProvider{ return &dynamicUpstreamIDPProvider{
oidcProviders: []UpstreamOIDCIdentityProvider{}, oidcProviders: []UpstreamOIDCIdentityProviderI{},
} }
} }
func (p *dynamicUpstreamIDPProvider) SetIDPList(oidcIDPs []UpstreamOIDCIdentityProvider) { func (p *dynamicUpstreamIDPProvider) SetIDPList(oidcIDPs []UpstreamOIDCIdentityProviderI) {
p.mutex.Lock() // acquire a write lock p.mutex.Lock() // acquire a write lock
defer p.mutex.Unlock() defer p.mutex.Unlock()
p.oidcProviders = oidcIDPs p.oidcProviders = oidcIDPs
} }
func (p *dynamicUpstreamIDPProvider) GetIDPList() []UpstreamOIDCIdentityProvider { func (p *dynamicUpstreamIDPProvider) GetIDPList() []UpstreamOIDCIdentityProviderI {
p.mutex.RLock() // acquire a read lock p.mutex.RLock() // acquire a read lock
defer p.mutex.RUnlock() defer p.mutex.RUnlock()
return p.oidcProviders return p.oidcProviders

View File

@ -12,6 +12,8 @@ import (
"strings" "strings"
"testing" "testing"
"go.pinniped.dev/internal/testutil"
"github.com/sclevine/spec" "github.com/sclevine/spec"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
@ -107,14 +109,11 @@ func TestManager(t *testing.T) {
parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL) parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL)
r.NoError(err) r.NoError(err)
idpListGetter := provider.NewDynamicUpstreamIDPProvider() idpListGetter := testutil.NewIDPListGetter(testutil.TestUpstreamOIDCIdentityProvider{
idpListGetter.SetIDPList([]provider.UpstreamOIDCIdentityProvider{ Name: "test-idp",
{ ClientID: "test-client-id",
Name: "test-idp", AuthorizationURL: *parsedUpstreamIDPAuthorizationURL,
ClientID: "test-client-id", Scopes: []string{"test-scope"},
AuthorizationURL: *parsedUpstreamIDPAuthorizationURL,
Scopes: []string{"test-scope"},
},
}) })
subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter) subject = NewManager(nextHandler, dynamicJWKSProvider, idpListGetter)

View File

@ -3,13 +3,73 @@
package testutil package testutil
import "go.pinniped.dev/internal/oidc/provider" import (
"context"
"net/url"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidcclient"
"go.pinniped.dev/internal/oidcclient/nonce"
"go.pinniped.dev/internal/oidcclient/pkce"
)
// Test helpers for the OIDC package. // Test helpers for the OIDC package.
func NewIDPListGetter(upstreamOIDCIdentityProviders ...provider.UpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { type TestUpstreamOIDCIdentityProvider struct {
Name string
ClientID string
AuthorizationURL url.URL
UsernameClaim string
GroupsClaim string
Scopes []string
ExchangeAuthcodeAndValidateTokensFunc func(
ctx context.Context,
authcode string,
pkceCodeVerifier pkce.Code,
expectedIDTokenNonce nonce.Nonce,
) (oidcclient.Token, map[string]interface{}, error)
}
func (u *TestUpstreamOIDCIdentityProvider) GetName() string {
return u.Name
}
func (u *TestUpstreamOIDCIdentityProvider) GetClientID() string {
return u.ClientID
}
func (u *TestUpstreamOIDCIdentityProvider) GetAuthorizationURL() *url.URL {
return &u.AuthorizationURL
}
func (u *TestUpstreamOIDCIdentityProvider) GetScopes() []string {
return u.Scopes
}
func (u *TestUpstreamOIDCIdentityProvider) GetUsernameClaim() string {
return u.UsernameClaim
}
func (u *TestUpstreamOIDCIdentityProvider) GetGroupsClaim() string {
return u.GroupsClaim
}
func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens(
ctx context.Context,
authcode string,
pkceCodeVerifier pkce.Code,
expectedIDTokenNonce nonce.Nonce,
) (oidcclient.Token, map[string]interface{}, error) {
return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce)
}
func NewIDPListGetter(upstreamOIDCIdentityProviders ...TestUpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider {
idpProvider := provider.NewDynamicUpstreamIDPProvider() idpProvider := provider.NewDynamicUpstreamIDPProvider()
idpProvider.SetIDPList(upstreamOIDCIdentityProviders) upstreams := make([]provider.UpstreamOIDCIdentityProviderI, len(upstreamOIDCIdentityProviders))
for i := range upstreamOIDCIdentityProviders {
upstreams[i] = provider.UpstreamOIDCIdentityProviderI(&upstreamOIDCIdentityProviders[i])
}
idpProvider.SetIDPList(upstreams)
return idpProvider return idpProvider
} }