// Copyright 2020-2023 the Pinniped contributors. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 package oidctestutil import ( "context" "crypto" "crypto/ecdsa" "fmt" "net/url" "regexp" "strings" "testing" "time" coreosoidc "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/securecookie" "github.com/ory/fosite" "github.com/stretchr/testify/require" "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/kubernetes/fake" v1 "k8s.io/client-go/kubernetes/typed/core/v1" kubetesting "k8s.io/client-go/testing" "k8s.io/utils/strings/slices" "go.pinniped.dev/internal/authenticators" "go.pinniped.dev/internal/crud" "go.pinniped.dev/internal/federationdomain/dynamicupstreamprovider" "go.pinniped.dev/internal/federationdomain/resolvedprovider" "go.pinniped.dev/internal/federationdomain/upstreamprovider" "go.pinniped.dev/internal/fositestorage/authorizationcode" "go.pinniped.dev/internal/fositestorage/openidconnect" "go.pinniped.dev/internal/fositestorage/pkce" "go.pinniped.dev/internal/fositestoragei" "go.pinniped.dev/internal/idtransform" "go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/testutil" "go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/oidctypes" oidcpkce "go.pinniped.dev/pkg/oidcclient/pkce" ) // Test helpers for the OIDC package. // ExchangeAuthcodeAndValidateTokenArgs is used to spy on calls to // TestUpstreamOIDCIdentityProvider.ExchangeAuthcodeAndValidateTokensFunc(). type ExchangeAuthcodeAndValidateTokenArgs struct { Ctx context.Context Authcode string PKCECodeVerifier oidcpkce.Code ExpectedIDTokenNonce nonce.Nonce RedirectURI string } // PasswordCredentialsGrantAndValidateTokensArgs is used to spy on calls to // TestUpstreamOIDCIdentityProvider.PasswordCredentialsGrantAndValidateTokensFunc(). type PasswordCredentialsGrantAndValidateTokensArgs struct { Ctx context.Context Username string Password string } // PerformRefreshArgs is used to spy on calls to // TestUpstreamOIDCIdentityProvider.PerformRefreshFunc(). type PerformRefreshArgs struct { Ctx context.Context RefreshToken string DN string ExpectedUsername string ExpectedSubject string } // RevokeTokenArgs is used to spy on calls to // TestUpstreamOIDCIdentityProvider.RevokeTokenArgsFunc(). type RevokeTokenArgs struct { Ctx context.Context Token string TokenType upstreamprovider.RevocableTokenType } // ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to // TestUpstreamOIDCIdentityProvider.ValidateTokenAndMergeWithUserInfoFunc(). type ValidateTokenAndMergeWithUserInfoArgs struct { Ctx context.Context Tok *oauth2.Token ExpectedIDTokenNonce nonce.Nonce RequireIDToken bool RequireUserInfo bool } type ValidateRefreshArgs struct { Ctx context.Context Tok *oauth2.Token StoredAttributes upstreamprovider.RefreshAttributes } func NewTestUpstreamLDAPIdentityProviderBuilder() *TestUpstreamLDAPIdentityProviderBuilder { return &TestUpstreamLDAPIdentityProviderBuilder{} } type TestUpstreamLDAPIdentityProviderBuilder struct { name string resourceUID types.UID url *url.URL authenticateFunc func(ctx context.Context, username, password string) (*authenticators.Response, bool, error) performRefreshCallCount int performRefreshArgs []*PerformRefreshArgs performRefreshErr error performRefreshGroups []string displayNameForFederationDomain string transformsForFederationDomain *idtransform.TransformationPipeline } func (t *TestUpstreamLDAPIdentityProviderBuilder) WithName(name string) *TestUpstreamLDAPIdentityProviderBuilder { t.name = name return t } func (t *TestUpstreamLDAPIdentityProviderBuilder) WithResourceUID(uid types.UID) *TestUpstreamLDAPIdentityProviderBuilder { t.resourceUID = uid return t } func (t *TestUpstreamLDAPIdentityProviderBuilder) WithURL(url *url.URL) *TestUpstreamLDAPIdentityProviderBuilder { t.url = url return t } func (t *TestUpstreamLDAPIdentityProviderBuilder) WithAuthenticateFunc(f func(ctx context.Context, username, password string) (*authenticators.Response, bool, error)) *TestUpstreamLDAPIdentityProviderBuilder { t.authenticateFunc = f return t } func (t *TestUpstreamLDAPIdentityProviderBuilder) WithPerformRefreshCallCount(count int) *TestUpstreamLDAPIdentityProviderBuilder { t.performRefreshCallCount = count return t } func (t *TestUpstreamLDAPIdentityProviderBuilder) WithPerformRefreshArgs(args []*PerformRefreshArgs) *TestUpstreamLDAPIdentityProviderBuilder { t.performRefreshArgs = args return t } func (t *TestUpstreamLDAPIdentityProviderBuilder) WithPerformRefreshErr(err error) *TestUpstreamLDAPIdentityProviderBuilder { t.performRefreshErr = err return t } func (t *TestUpstreamLDAPIdentityProviderBuilder) WithPerformRefreshGroups(groups []string) *TestUpstreamLDAPIdentityProviderBuilder { t.performRefreshGroups = groups return t } func (t *TestUpstreamLDAPIdentityProviderBuilder) WithDisplayNameForFederationDomain(displayName string) *TestUpstreamLDAPIdentityProviderBuilder { t.displayNameForFederationDomain = displayName return t } func (t *TestUpstreamLDAPIdentityProviderBuilder) WithTransformsForFederationDomain(transforms *idtransform.TransformationPipeline) *TestUpstreamLDAPIdentityProviderBuilder { t.transformsForFederationDomain = transforms return t } func (t *TestUpstreamLDAPIdentityProviderBuilder) Build() *TestUpstreamLDAPIdentityProvider { if t.displayNameForFederationDomain == "" { // default it to the CR name t.displayNameForFederationDomain = t.name } if t.transformsForFederationDomain == nil { // default to an empty pipeline t.transformsForFederationDomain = idtransform.NewTransformationPipeline() } return &TestUpstreamLDAPIdentityProvider{ Name: t.name, ResourceUID: t.resourceUID, URL: t.url, AuthenticateFunc: t.authenticateFunc, performRefreshCallCount: t.performRefreshCallCount, performRefreshArgs: t.performRefreshArgs, PerformRefreshErr: t.performRefreshErr, PerformRefreshGroups: t.performRefreshGroups, DisplayNameForFederationDomain: t.displayNameForFederationDomain, TransformsForFederationDomain: t.transformsForFederationDomain, } } type TestUpstreamLDAPIdentityProvider struct { Name string ResourceUID types.UID URL *url.URL AuthenticateFunc func(ctx context.Context, username, password string) (*authenticators.Response, bool, error) performRefreshCallCount int performRefreshArgs []*PerformRefreshArgs PerformRefreshErr error PerformRefreshGroups []string DisplayNameForFederationDomain string TransformsForFederationDomain *idtransform.TransformationPipeline } var _ upstreamprovider.UpstreamLDAPIdentityProviderI = &TestUpstreamLDAPIdentityProvider{} func (u *TestUpstreamLDAPIdentityProvider) GetResourceUID() types.UID { return u.ResourceUID } func (u *TestUpstreamLDAPIdentityProvider) GetName() string { return u.Name } func (u *TestUpstreamLDAPIdentityProvider) AuthenticateUser(ctx context.Context, username, password string, grantedScopes []string) (*authenticators.Response, bool, error) { return u.AuthenticateFunc(ctx, username, password) } func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL { return u.URL } func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes upstreamprovider.RefreshAttributes) ([]string, error) { if u.performRefreshArgs == nil { u.performRefreshArgs = make([]*PerformRefreshArgs, 0) } u.performRefreshCallCount++ u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{ Ctx: ctx, DN: storedRefreshAttributes.DN, ExpectedUsername: storedRefreshAttributes.Username, ExpectedSubject: storedRefreshAttributes.Subject, }) if u.PerformRefreshErr != nil { return nil, u.PerformRefreshErr } return u.PerformRefreshGroups, nil } func (u *TestUpstreamLDAPIdentityProvider) PerformRefreshCallCount() int { return u.performRefreshCallCount } func (u *TestUpstreamLDAPIdentityProvider) PerformRefreshArgs(call int) *PerformRefreshArgs { if u.performRefreshArgs == nil { u.performRefreshArgs = make([]*PerformRefreshArgs, 0) } return u.performRefreshArgs[call] } type TestUpstreamOIDCIdentityProvider struct { Name string ClientID string ResourceUID types.UID AuthorizationURL url.URL UserInfoURL bool RevocationURL *url.URL UsernameClaim string GroupsClaim string Scopes []string AdditionalAuthcodeParams map[string]string AdditionalClaimMappings map[string]string AllowPasswordGrant bool DisplayNameForFederationDomain string TransformsForFederationDomain *idtransform.TransformationPipeline ExchangeAuthcodeAndValidateTokensFunc func( ctx context.Context, authcode string, pkceCodeVerifier oidcpkce.Code, expectedIDTokenNonce nonce.Nonce, ) (*oidctypes.Token, error) PasswordCredentialsGrantAndValidateTokensFunc func( ctx context.Context, username string, password string, ) (*oidctypes.Token, error) PerformRefreshFunc func(ctx context.Context, refreshToken string) (*oauth2.Token, error) RevokeTokenFunc func(ctx context.Context, refreshToken string, tokenType upstreamprovider.RevocableTokenType) error ValidateTokenAndMergeWithUserInfoFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) exchangeAuthcodeAndValidateTokensCallCount int exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs passwordCredentialsGrantAndValidateTokensCallCount int passwordCredentialsGrantAndValidateTokensArgs []*PasswordCredentialsGrantAndValidateTokensArgs performRefreshCallCount int performRefreshArgs []*PerformRefreshArgs revokeTokenCallCount int revokeTokenArgs []*RevokeTokenArgs validateTokenAndMergeWithUserInfoCallCount int validateTokenAndMergeWithUserInfoArgs []*ValidateTokenAndMergeWithUserInfoArgs } var _ upstreamprovider.UpstreamOIDCIdentityProviderI = &TestUpstreamOIDCIdentityProvider{} func (u *TestUpstreamOIDCIdentityProvider) GetResourceUID() types.UID { return u.ResourceUID } func (u *TestUpstreamOIDCIdentityProvider) GetAdditionalAuthcodeParams() map[string]string { return u.AdditionalAuthcodeParams } func (u *TestUpstreamOIDCIdentityProvider) GetAdditionalClaimMappings() map[string]string { return u.AdditionalClaimMappings } func (u *TestUpstreamOIDCIdentityProvider) GetName() string { return u.Name } func (u *TestUpstreamOIDCIdentityProvider) GetClientID() string { return u.ClientID } func (u *TestUpstreamOIDCIdentityProvider) GetAuthorizationURL() *url.URL { return &u.AuthorizationURL } func (u *TestUpstreamOIDCIdentityProvider) HasUserInfoURL() bool { return u.UserInfoURL } func (u *TestUpstreamOIDCIdentityProvider) GetRevocationURL() *url.URL { return u.RevocationURL } func (u *TestUpstreamOIDCIdentityProvider) GetScopes() []string { return u.Scopes } func (u *TestUpstreamOIDCIdentityProvider) GetUsernameClaim() string { return u.UsernameClaim } func (u *TestUpstreamOIDCIdentityProvider) GetGroupsClaim() string { return u.GroupsClaim } func (u *TestUpstreamOIDCIdentityProvider) AllowsPasswordGrant() bool { return u.AllowPasswordGrant } func (u *TestUpstreamOIDCIdentityProvider) PasswordCredentialsGrantAndValidateTokens(ctx context.Context, username, password string) (*oidctypes.Token, error) { u.passwordCredentialsGrantAndValidateTokensCallCount++ u.passwordCredentialsGrantAndValidateTokensArgs = append(u.passwordCredentialsGrantAndValidateTokensArgs, &PasswordCredentialsGrantAndValidateTokensArgs{ Ctx: ctx, Username: username, Password: password, }) return u.PasswordCredentialsGrantAndValidateTokensFunc(ctx, username, password) } func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( ctx context.Context, authcode string, pkceCodeVerifier oidcpkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string, ) (*oidctypes.Token, error) { if u.exchangeAuthcodeAndValidateTokensArgs == nil { u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) } u.exchangeAuthcodeAndValidateTokensCallCount++ u.exchangeAuthcodeAndValidateTokensArgs = append(u.exchangeAuthcodeAndValidateTokensArgs, &ExchangeAuthcodeAndValidateTokenArgs{ Ctx: ctx, Authcode: authcode, PKCECodeVerifier: pkceCodeVerifier, ExpectedIDTokenNonce: expectedIDTokenNonce, RedirectURI: redirectURI, }) return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce) } func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensCallCount() int { return u.exchangeAuthcodeAndValidateTokensCallCount } func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs(call int) *ExchangeAuthcodeAndValidateTokenArgs { if u.exchangeAuthcodeAndValidateTokensArgs == nil { u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) } return u.exchangeAuthcodeAndValidateTokensArgs[call] } func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error) { if u.performRefreshArgs == nil { u.performRefreshArgs = make([]*PerformRefreshArgs, 0) } u.performRefreshCallCount++ u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{ Ctx: ctx, RefreshToken: refreshToken, }) return u.PerformRefreshFunc(ctx, refreshToken) } func (u *TestUpstreamOIDCIdentityProvider) RevokeToken(ctx context.Context, token string, tokenType upstreamprovider.RevocableTokenType) error { if u.revokeTokenArgs == nil { u.revokeTokenArgs = make([]*RevokeTokenArgs, 0) } u.revokeTokenCallCount++ u.revokeTokenArgs = append(u.revokeTokenArgs, &RevokeTokenArgs{ Ctx: ctx, Token: token, TokenType: tokenType, }) return u.RevokeTokenFunc(ctx, token, tokenType) } func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshCallCount() int { return u.performRefreshCallCount } func (u *TestUpstreamOIDCIdentityProvider) PerformRefreshArgs(call int) *PerformRefreshArgs { if u.performRefreshArgs == nil { u.performRefreshArgs = make([]*PerformRefreshArgs, 0) } return u.performRefreshArgs[call] } func (u *TestUpstreamOIDCIdentityProvider) RevokeTokenCallCount() int { return u.performRefreshCallCount } func (u *TestUpstreamOIDCIdentityProvider) RevokeTokenArgs(call int) *RevokeTokenArgs { if u.revokeTokenArgs == nil { u.revokeTokenArgs = make([]*RevokeTokenArgs, 0) } return u.revokeTokenArgs[call] } func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool, requireUserInfo bool) (*oidctypes.Token, error) { if u.validateTokenAndMergeWithUserInfoArgs == nil { u.validateTokenAndMergeWithUserInfoArgs = make([]*ValidateTokenAndMergeWithUserInfoArgs, 0) } u.validateTokenAndMergeWithUserInfoCallCount++ u.validateTokenAndMergeWithUserInfoArgs = append(u.validateTokenAndMergeWithUserInfoArgs, &ValidateTokenAndMergeWithUserInfoArgs{ Ctx: ctx, Tok: tok, ExpectedIDTokenNonce: expectedIDTokenNonce, RequireIDToken: requireIDToken, RequireUserInfo: requireUserInfo, }) return u.ValidateTokenAndMergeWithUserInfoFunc(ctx, tok, expectedIDTokenNonce) } func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfoCallCount() int { return u.validateTokenAndMergeWithUserInfoCallCount } func (u *TestUpstreamOIDCIdentityProvider) ValidateTokenAndMergeWithUserInfoArgs(call int) *ValidateTokenAndMergeWithUserInfoArgs { if u.validateTokenAndMergeWithUserInfoArgs == nil { u.validateTokenAndMergeWithUserInfoArgs = make([]*ValidateTokenAndMergeWithUserInfoArgs, 0) } return u.validateTokenAndMergeWithUserInfoArgs[call] } type TestFederationDomainIdentityProvidersListerFinder struct { upstreamOIDCIdentityProviders []*TestUpstreamOIDCIdentityProvider upstreamLDAPIdentityProviders []*TestUpstreamLDAPIdentityProvider upstreamActiveDirectoryIdentityProviders []*TestUpstreamLDAPIdentityProvider defaultIDPDisplayName string } func (t *TestFederationDomainIdentityProvidersListerFinder) GetOIDCIdentityProviders() []*resolvedprovider.FederationDomainResolvedOIDCIdentityProvider { fdIDPs := make([]*resolvedprovider.FederationDomainResolvedOIDCIdentityProvider, len(t.upstreamOIDCIdentityProviders)) for i, testIDP := range t.upstreamOIDCIdentityProviders { fdIDP := &resolvedprovider.FederationDomainResolvedOIDCIdentityProvider{ DisplayName: testIDP.DisplayNameForFederationDomain, Provider: testIDP, SessionProviderType: psession.ProviderTypeOIDC, Transforms: testIDP.TransformsForFederationDomain, } fdIDPs[i] = fdIDP } return fdIDPs } func (t *TestFederationDomainIdentityProvidersListerFinder) GetLDAPIdentityProviders() []*resolvedprovider.FederationDomainResolvedLDAPIdentityProvider { fdIDPs := make([]*resolvedprovider.FederationDomainResolvedLDAPIdentityProvider, len(t.upstreamLDAPIdentityProviders)) for i, testIDP := range t.upstreamLDAPIdentityProviders { fdIDP := &resolvedprovider.FederationDomainResolvedLDAPIdentityProvider{ DisplayName: testIDP.DisplayNameForFederationDomain, Provider: testIDP, SessionProviderType: psession.ProviderTypeLDAP, Transforms: testIDP.TransformsForFederationDomain, } fdIDPs[i] = fdIDP } return fdIDPs } func (t *TestFederationDomainIdentityProvidersListerFinder) GetActiveDirectoryIdentityProviders() []*resolvedprovider.FederationDomainResolvedLDAPIdentityProvider { fdIDPs := make([]*resolvedprovider.FederationDomainResolvedLDAPIdentityProvider, len(t.upstreamActiveDirectoryIdentityProviders)) for i, testIDP := range t.upstreamActiveDirectoryIdentityProviders { fdIDP := &resolvedprovider.FederationDomainResolvedLDAPIdentityProvider{ DisplayName: testIDP.DisplayNameForFederationDomain, Provider: testIDP, SessionProviderType: psession.ProviderTypeActiveDirectory, Transforms: testIDP.TransformsForFederationDomain, } fdIDPs[i] = fdIDP } return fdIDPs } func (t *TestFederationDomainIdentityProvidersListerFinder) FindDefaultIDP() (*resolvedprovider.FederationDomainResolvedOIDCIdentityProvider, *resolvedprovider.FederationDomainResolvedLDAPIdentityProvider, error) { if t.defaultIDPDisplayName == "" { return nil, nil, fmt.Errorf("identity provider not found: this federation domain does not have a default identity provider") } return t.FindUpstreamIDPByDisplayName(t.defaultIDPDisplayName) } func (t *TestFederationDomainIdentityProvidersListerFinder) FindUpstreamIDPByDisplayName(upstreamIDPDisplayName string) (*resolvedprovider.FederationDomainResolvedOIDCIdentityProvider, *resolvedprovider.FederationDomainResolvedLDAPIdentityProvider, error) { for _, testIDP := range t.upstreamOIDCIdentityProviders { if upstreamIDPDisplayName == testIDP.DisplayNameForFederationDomain { return &resolvedprovider.FederationDomainResolvedOIDCIdentityProvider{ DisplayName: testIDP.DisplayNameForFederationDomain, Provider: testIDP, SessionProviderType: psession.ProviderTypeOIDC, Transforms: testIDP.TransformsForFederationDomain, }, nil, nil } } for _, testIDP := range t.upstreamLDAPIdentityProviders { if upstreamIDPDisplayName == testIDP.DisplayNameForFederationDomain { return nil, &resolvedprovider.FederationDomainResolvedLDAPIdentityProvider{ DisplayName: testIDP.DisplayNameForFederationDomain, Provider: testIDP, SessionProviderType: psession.ProviderTypeLDAP, Transforms: testIDP.TransformsForFederationDomain, }, nil } } for _, testIDP := range t.upstreamActiveDirectoryIdentityProviders { if upstreamIDPDisplayName == testIDP.DisplayNameForFederationDomain { return nil, &resolvedprovider.FederationDomainResolvedLDAPIdentityProvider{ DisplayName: testIDP.DisplayNameForFederationDomain, Provider: testIDP, SessionProviderType: psession.ProviderTypeActiveDirectory, Transforms: testIDP.TransformsForFederationDomain, }, nil } } return nil, nil, fmt.Errorf("did not find IDP with name %q", upstreamIDPDisplayName) } func (t *TestFederationDomainIdentityProvidersListerFinder) SetOIDCIdentityProviders(providers []*TestUpstreamOIDCIdentityProvider) { t.upstreamOIDCIdentityProviders = providers } func (t *TestFederationDomainIdentityProvidersListerFinder) SetLDAPIdentityProviders(providers []*TestUpstreamLDAPIdentityProvider) { t.upstreamLDAPIdentityProviders = providers } func (t *TestFederationDomainIdentityProvidersListerFinder) SetActiveDirectoryIdentityProviders(providers []*TestUpstreamLDAPIdentityProvider) { t.upstreamActiveDirectoryIdentityProviders = providers } type UpstreamIDPListerBuilder struct { upstreamOIDCIdentityProviders []*TestUpstreamOIDCIdentityProvider upstreamLDAPIdentityProviders []*TestUpstreamLDAPIdentityProvider upstreamActiveDirectoryIdentityProviders []*TestUpstreamLDAPIdentityProvider defaultIDPDisplayName string } func (b *UpstreamIDPListerBuilder) WithOIDC(upstreamOIDCIdentityProviders ...*TestUpstreamOIDCIdentityProvider) *UpstreamIDPListerBuilder { b.upstreamOIDCIdentityProviders = append(b.upstreamOIDCIdentityProviders, upstreamOIDCIdentityProviders...) return b } func (b *UpstreamIDPListerBuilder) WithLDAP(upstreamLDAPIdentityProviders ...*TestUpstreamLDAPIdentityProvider) *UpstreamIDPListerBuilder { b.upstreamLDAPIdentityProviders = append(b.upstreamLDAPIdentityProviders, upstreamLDAPIdentityProviders...) return b } func (b *UpstreamIDPListerBuilder) WithActiveDirectory(upstreamActiveDirectoryIdentityProviders ...*TestUpstreamLDAPIdentityProvider) *UpstreamIDPListerBuilder { b.upstreamActiveDirectoryIdentityProviders = append(b.upstreamActiveDirectoryIdentityProviders, upstreamActiveDirectoryIdentityProviders...) return b } func (b *UpstreamIDPListerBuilder) WithDefaultIDPDisplayName(defaultIDPDisplayName string) *UpstreamIDPListerBuilder { b.defaultIDPDisplayName = defaultIDPDisplayName return b } func (b *UpstreamIDPListerBuilder) BuildFederationDomainIdentityProvidersListerFinder() *TestFederationDomainIdentityProvidersListerFinder { return &TestFederationDomainIdentityProvidersListerFinder{ upstreamOIDCIdentityProviders: b.upstreamOIDCIdentityProviders, upstreamLDAPIdentityProviders: b.upstreamLDAPIdentityProviders, upstreamActiveDirectoryIdentityProviders: b.upstreamActiveDirectoryIdentityProviders, defaultIDPDisplayName: b.defaultIDPDisplayName, } } func (b *UpstreamIDPListerBuilder) BuildDynamicUpstreamIDPProvider() dynamicupstreamprovider.DynamicUpstreamIDPProvider { idpProvider := dynamicupstreamprovider.NewDynamicUpstreamIDPProvider() oidcUpstreams := make([]upstreamprovider.UpstreamOIDCIdentityProviderI, len(b.upstreamOIDCIdentityProviders)) for i := range b.upstreamOIDCIdentityProviders { oidcUpstreams[i] = upstreamprovider.UpstreamOIDCIdentityProviderI(b.upstreamOIDCIdentityProviders[i]) } idpProvider.SetOIDCIdentityProviders(oidcUpstreams) ldapUpstreams := make([]upstreamprovider.UpstreamLDAPIdentityProviderI, len(b.upstreamLDAPIdentityProviders)) for i := range b.upstreamLDAPIdentityProviders { ldapUpstreams[i] = upstreamprovider.UpstreamLDAPIdentityProviderI(b.upstreamLDAPIdentityProviders[i]) } idpProvider.SetLDAPIdentityProviders(ldapUpstreams) adUpstreams := make([]upstreamprovider.UpstreamLDAPIdentityProviderI, len(b.upstreamActiveDirectoryIdentityProviders)) for i := range b.upstreamActiveDirectoryIdentityProviders { adUpstreams[i] = upstreamprovider.UpstreamLDAPIdentityProviderI(b.upstreamActiveDirectoryIdentityProviders[i]) } idpProvider.SetActiveDirectoryIdentityProviders(adUpstreams) return idpProvider } func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPasswordCredentialsGrantAndValidateTokens( t *testing.T, expectedPerformedByUpstreamName string, expectedArgs *PasswordCredentialsGrantAndValidateTokensArgs, ) { t.Helper() var actualArgs *PasswordCredentialsGrantAndValidateTokensArgs var actualNameOfUpstreamWhichMadeCall string actualCallCountAcrossAllOIDCUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { callCountOnThisUpstream := upstreamOIDC.passwordCredentialsGrantAndValidateTokensCallCount actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream if callCountOnThisUpstream == 1 { actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name actualArgs = upstreamOIDC.passwordCredentialsGrantAndValidateTokensArgs[0] } } require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, "should have been exactly one call to PasswordCredentialsGrantAndValidateTokens() by all OIDC upstreams", ) require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, "PasswordCredentialsGrantAndValidateTokens() was called on the wrong OIDC upstream", ) require.Equal(t, expectedArgs, actualArgs) } func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPasswordCredentialsGrantAndValidateTokens(t *testing.T) { t.Helper() actualCallCountAcrossAllOIDCUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.passwordCredentialsGrantAndValidateTokensCallCount } require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, "expected exactly zero calls to PasswordCredentialsGrantAndValidateTokens()", ) } func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToExchangeAuthcodeAndValidateTokens( t *testing.T, expectedPerformedByUpstreamName string, expectedArgs *ExchangeAuthcodeAndValidateTokenArgs, ) { t.Helper() var actualArgs *ExchangeAuthcodeAndValidateTokenArgs var actualNameOfUpstreamWhichMadeCall string actualCallCountAcrossAllOIDCUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { callCountOnThisUpstream := upstreamOIDC.exchangeAuthcodeAndValidateTokensCallCount actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream if callCountOnThisUpstream == 1 { actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name actualArgs = upstreamOIDC.exchangeAuthcodeAndValidateTokensArgs[0] } } require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, "should have been exactly one call to ExchangeAuthcodeAndValidateTokens() by all OIDC upstreams", ) require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, "ExchangeAuthcodeAndValidateTokens() was called on the wrong OIDC upstream", ) require.Equal(t, expectedArgs, actualArgs) } func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToExchangeAuthcodeAndValidateTokens(t *testing.T) { t.Helper() actualCallCountAcrossAllOIDCUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.exchangeAuthcodeAndValidateTokensCallCount } require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, "expected exactly zero calls to ExchangeAuthcodeAndValidateTokens()", ) } func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToPerformRefresh( t *testing.T, expectedPerformedByUpstreamName string, expectedArgs *PerformRefreshArgs, ) { t.Helper() var actualArgs *PerformRefreshArgs var actualNameOfUpstreamWhichMadeCall string actualCallCountAcrossAllUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { callCountOnThisUpstream := upstreamOIDC.performRefreshCallCount actualCallCountAcrossAllUpstreams += callCountOnThisUpstream if callCountOnThisUpstream == 1 { actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name actualArgs = upstreamOIDC.performRefreshArgs[0] } } for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders { callCountOnThisUpstream := upstreamLDAP.performRefreshCallCount actualCallCountAcrossAllUpstreams += callCountOnThisUpstream if callCountOnThisUpstream == 1 { actualNameOfUpstreamWhichMadeCall = upstreamLDAP.Name actualArgs = upstreamLDAP.performRefreshArgs[0] } } for _, upstreamAD := range b.upstreamActiveDirectoryIdentityProviders { callCountOnThisUpstream := upstreamAD.performRefreshCallCount actualCallCountAcrossAllUpstreams += callCountOnThisUpstream if callCountOnThisUpstream == 1 { actualNameOfUpstreamWhichMadeCall = upstreamAD.Name actualArgs = upstreamAD.performRefreshArgs[0] } } require.Equal(t, 1, actualCallCountAcrossAllUpstreams, "should have been exactly one call to PerformRefresh() by all upstreams", ) require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, "PerformRefresh() was called on the wrong upstream", ) require.Equal(t, expectedArgs, actualArgs) } func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToPerformRefresh(t *testing.T) { t.Helper() actualCallCountAcrossAllUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { actualCallCountAcrossAllUpstreams += upstreamOIDC.performRefreshCallCount } for _, upstreamLDAP := range b.upstreamLDAPIdentityProviders { actualCallCountAcrossAllUpstreams += upstreamLDAP.performRefreshCallCount } for _, upstreamActiveDirectory := range b.upstreamActiveDirectoryIdentityProviders { actualCallCountAcrossAllUpstreams += upstreamActiveDirectory.performRefreshCallCount } require.Equal(t, 0, actualCallCountAcrossAllUpstreams, "expected exactly zero calls to PerformRefresh()", ) } func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToValidateToken( t *testing.T, expectedPerformedByUpstreamName string, expectedArgs *ValidateTokenAndMergeWithUserInfoArgs, ) { t.Helper() var actualArgs *ValidateTokenAndMergeWithUserInfoArgs var actualNameOfUpstreamWhichMadeCall string actualCallCountAcrossAllOIDCUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { callCountOnThisUpstream := upstreamOIDC.validateTokenAndMergeWithUserInfoCallCount actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream if callCountOnThisUpstream == 1 { actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name actualArgs = upstreamOIDC.validateTokenAndMergeWithUserInfoArgs[0] } } require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, "should have been exactly one call to ValidateTokenAndMergeWithUserInfo() by all OIDC upstreams", ) require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, "ValidateTokenAndMergeWithUserInfo() was called on the wrong OIDC upstream", ) require.Equal(t, expectedArgs, actualArgs) } func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToValidateToken(t *testing.T) { t.Helper() actualCallCountAcrossAllOIDCUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.validateTokenAndMergeWithUserInfoCallCount } require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, "expected exactly zero calls to ValidateTokenAndMergeWithUserInfo()", ) } func (b *UpstreamIDPListerBuilder) RequireExactlyOneCallToRevokeToken( t *testing.T, expectedPerformedByUpstreamName string, expectedArgs *RevokeTokenArgs, ) { t.Helper() var actualArgs *RevokeTokenArgs var actualNameOfUpstreamWhichMadeCall string actualCallCountAcrossAllOIDCUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { callCountOnThisUpstream := upstreamOIDC.revokeTokenCallCount actualCallCountAcrossAllOIDCUpstreams += callCountOnThisUpstream if callCountOnThisUpstream == 1 { actualNameOfUpstreamWhichMadeCall = upstreamOIDC.Name actualArgs = upstreamOIDC.revokeTokenArgs[0] } } require.Equal(t, 1, actualCallCountAcrossAllOIDCUpstreams, "should have been exactly one call to RevokeToken() by all OIDC upstreams", ) require.Equal(t, expectedPerformedByUpstreamName, actualNameOfUpstreamWhichMadeCall, "RevokeToken() was called on the wrong OIDC upstream", ) require.Equal(t, expectedArgs, actualArgs) } func (b *UpstreamIDPListerBuilder) RequireExactlyZeroCallsToRevokeToken(t *testing.T) { t.Helper() actualCallCountAcrossAllOIDCUpstreams := 0 for _, upstreamOIDC := range b.upstreamOIDCIdentityProviders { actualCallCountAcrossAllOIDCUpstreams += upstreamOIDC.revokeTokenCallCount } require.Equal(t, 0, actualCallCountAcrossAllOIDCUpstreams, "expected exactly zero calls to RevokeToken()", ) } func NewUpstreamIDPListerBuilder() *UpstreamIDPListerBuilder { return &UpstreamIDPListerBuilder{} } type TestUpstreamOIDCIdentityProviderBuilder struct { name string resourceUID types.UID clientID string scopes []string idToken map[string]interface{} refreshToken *oidctypes.RefreshToken accessToken *oidctypes.AccessToken usernameClaim string groupsClaim string refreshedTokens *oauth2.Token validatedAndMergedWithUserInfoTokens *oidctypes.Token authorizationURL url.URL hasUserInfoURL bool additionalAuthcodeParams map[string]string additionalClaimMappings map[string]string allowPasswordGrant bool authcodeExchangeErr error passwordGrantErr error performRefreshErr error revokeTokenErr error validateTokenAndMergeWithUserInfoErr error displayNameForFederationDomain string transformsForFederationDomain *idtransform.TransformationPipeline } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithName(value string) *TestUpstreamOIDCIdentityProviderBuilder { u.name = value return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithResourceUID(value types.UID) *TestUpstreamOIDCIdentityProviderBuilder { u.resourceUID = value return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithClientID(value string) *TestUpstreamOIDCIdentityProviderBuilder { u.clientID = value return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithAuthorizationURL(value url.URL) *TestUpstreamOIDCIdentityProviderBuilder { u.authorizationURL = value return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithUserInfoURL() *TestUpstreamOIDCIdentityProviderBuilder { u.hasUserInfoURL = true return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithoutUserInfoURL() *TestUpstreamOIDCIdentityProviderBuilder { u.hasUserInfoURL = false return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithAllowPasswordGrant(value bool) *TestUpstreamOIDCIdentityProviderBuilder { u.allowPasswordGrant = value return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithScopes(values []string) *TestUpstreamOIDCIdentityProviderBuilder { u.scopes = values return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithUsernameClaim(value string) *TestUpstreamOIDCIdentityProviderBuilder { u.usernameClaim = value return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithoutUsernameClaim() *TestUpstreamOIDCIdentityProviderBuilder { u.usernameClaim = "" return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithGroupsClaim(value string) *TestUpstreamOIDCIdentityProviderBuilder { u.groupsClaim = value return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithoutGroupsClaim() *TestUpstreamOIDCIdentityProviderBuilder { u.groupsClaim = "" return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithIDTokenClaim(name string, value interface{}) *TestUpstreamOIDCIdentityProviderBuilder { if u.idToken == nil { u.idToken = map[string]interface{}{} } u.idToken[name] = value return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithoutIDTokenClaim(claim string) *TestUpstreamOIDCIdentityProviderBuilder { delete(u.idToken, claim) return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithAdditionalAuthcodeParams(params map[string]string) *TestUpstreamOIDCIdentityProviderBuilder { u.additionalAuthcodeParams = params return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithAdditionalClaimMappings(m map[string]string) *TestUpstreamOIDCIdentityProviderBuilder { u.additionalClaimMappings = m return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRefreshToken(token string) *TestUpstreamOIDCIdentityProviderBuilder { u.refreshToken = &oidctypes.RefreshToken{Token: token} return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithEmptyRefreshToken() *TestUpstreamOIDCIdentityProviderBuilder { u.refreshToken = &oidctypes.RefreshToken{Token: ""} return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithoutRefreshToken() *TestUpstreamOIDCIdentityProviderBuilder { u.refreshToken = nil return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithAccessToken(token string, expiry metav1.Time) *TestUpstreamOIDCIdentityProviderBuilder { u.accessToken = &oidctypes.AccessToken{Token: token, Expiry: expiry} return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithEmptyAccessToken() *TestUpstreamOIDCIdentityProviderBuilder { u.accessToken = &oidctypes.AccessToken{Token: ""} return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithoutAccessToken() *TestUpstreamOIDCIdentityProviderBuilder { u.accessToken = nil return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithUpstreamAuthcodeExchangeError(err error) *TestUpstreamOIDCIdentityProviderBuilder { u.authcodeExchangeErr = err return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithPasswordGrantError(err error) *TestUpstreamOIDCIdentityProviderBuilder { u.passwordGrantErr = err return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRefreshedTokens(tokens *oauth2.Token) *TestUpstreamOIDCIdentityProviderBuilder { u.refreshedTokens = tokens return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithPerformRefreshError(err error) *TestUpstreamOIDCIdentityProviderBuilder { u.performRefreshErr = err return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidatedAndMergedWithUserInfoTokens(tokens *oidctypes.Token) *TestUpstreamOIDCIdentityProviderBuilder { u.validatedAndMergedWithUserInfoTokens = tokens return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithValidateTokenAndMergeWithUserInfoError(err error) *TestUpstreamOIDCIdentityProviderBuilder { u.validateTokenAndMergeWithUserInfoErr = err return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithRevokeTokenError(err error) *TestUpstreamOIDCIdentityProviderBuilder { u.revokeTokenErr = err return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithDisplayNameForFederationDomain(displayName string) *TestUpstreamOIDCIdentityProviderBuilder { u.displayNameForFederationDomain = displayName return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) WithTransformsForFederationDomain(transforms *idtransform.TransformationPipeline) *TestUpstreamOIDCIdentityProviderBuilder { u.transformsForFederationDomain = transforms return u } func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdentityProvider { if u.displayNameForFederationDomain == "" { // default it to the CR name u.displayNameForFederationDomain = u.name } if u.transformsForFederationDomain == nil { // default to an empty pipeline u.transformsForFederationDomain = idtransform.NewTransformationPipeline() } return &TestUpstreamOIDCIdentityProvider{ Name: u.name, ClientID: u.clientID, ResourceUID: u.resourceUID, UsernameClaim: u.usernameClaim, GroupsClaim: u.groupsClaim, Scopes: u.scopes, AllowPasswordGrant: u.allowPasswordGrant, AuthorizationURL: u.authorizationURL, UserInfoURL: u.hasUserInfoURL, AdditionalAuthcodeParams: u.additionalAuthcodeParams, AdditionalClaimMappings: u.additionalClaimMappings, DisplayNameForFederationDomain: u.displayNameForFederationDomain, TransformsForFederationDomain: u.transformsForFederationDomain, ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier oidcpkce.Code, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { if u.authcodeExchangeErr != nil { return nil, u.authcodeExchangeErr } return &oidctypes.Token{IDToken: &oidctypes.IDToken{Claims: u.idToken}, RefreshToken: u.refreshToken, AccessToken: u.accessToken}, nil }, PasswordCredentialsGrantAndValidateTokensFunc: func(ctx context.Context, username, password string) (*oidctypes.Token, error) { if u.passwordGrantErr != nil { return nil, u.passwordGrantErr } return &oidctypes.Token{IDToken: &oidctypes.IDToken{Claims: u.idToken}, RefreshToken: u.refreshToken, AccessToken: u.accessToken}, nil }, PerformRefreshFunc: func(ctx context.Context, refreshToken string) (*oauth2.Token, error) { if u.performRefreshErr != nil { return nil, u.performRefreshErr } return u.refreshedTokens, nil }, RevokeTokenFunc: func(ctx context.Context, refreshToken string, tokenType upstreamprovider.RevocableTokenType) error { return u.revokeTokenErr }, ValidateTokenAndMergeWithUserInfoFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { if u.validateTokenAndMergeWithUserInfoErr != nil { return nil, u.validateTokenAndMergeWithUserInfoErr } return u.validatedAndMergedWithUserInfoTokens, nil }, } } func NewTestUpstreamOIDCIdentityProviderBuilder() *TestUpstreamOIDCIdentityProviderBuilder { return &TestUpstreamOIDCIdentityProviderBuilder{} } // Declare a separate type from the production code to ensure that the state param's contents was serialized // in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of // the serialized fields is the same, which doesn't really matter expect that we can make simpler equality // assertions about the redirect URL in this test. type ExpectedUpstreamStateParamFormat struct { P string `json:"p"` U string `json:"u"` T string `json:"t"` N string `json:"n"` C string `json:"c"` K string `json:"k"` V string `json:"v"` } type UpstreamStateParamBuilder ExpectedUpstreamStateParamFormat func (b UpstreamStateParamBuilder) Build(t *testing.T, stateEncoder *securecookie.SecureCookie) string { state, err := stateEncoder.Encode("s", b) require.NoError(t, err) return state } func (b *UpstreamStateParamBuilder) WithAuthorizeRequestParams(params string) *UpstreamStateParamBuilder { b.P = params return b } func (b *UpstreamStateParamBuilder) WithNonce(nonce string) *UpstreamStateParamBuilder { b.N = nonce return b } func (b *UpstreamStateParamBuilder) WithCSRF(csrf string) *UpstreamStateParamBuilder { b.C = csrf return b } func (b *UpstreamStateParamBuilder) WithPKCE(pkce string) *UpstreamStateParamBuilder { b.K = pkce return b } func (b *UpstreamStateParamBuilder) WithUpstreamIDPType(upstreamIDPType string) *UpstreamStateParamBuilder { b.T = upstreamIDPType return b } func (b *UpstreamStateParamBuilder) WithStateVersion(version string) *UpstreamStateParamBuilder { b.V = version return b } type staticKeySet struct { publicKey crypto.PublicKey } func newStaticKeySet(publicKey crypto.PublicKey) coreosoidc.KeySet { return &staticKeySet{publicKey} } func (s *staticKeySet) VerifySignature(_ context.Context, jwt string) ([]byte, error) { jws, err := jose.ParseSigned(jwt) if err != nil { return nil, fmt.Errorf("oidc: malformed jwt: %w", err) } return jws.Verify(s.publicKey) } // VerifyECDSAIDToken verifies that the provided idToken was issued via the provided jwtSigningKey. // It also performs some light validation on the claims, i.e., it makes sure the provided idToken // has the provided issuer and clientID. // // Further validation can be done via callers via the returned coreosoidc.IDToken. func VerifyECDSAIDToken( t *testing.T, issuer, clientID string, jwtSigningKey *ecdsa.PrivateKey, idToken string, ) *coreosoidc.IDToken { t.Helper() keySet := newStaticKeySet(jwtSigningKey.Public()) verifyConfig := coreosoidc.Config{ClientID: clientID, SupportedSigningAlgs: []string{coreosoidc.ES256}} verifier := coreosoidc.NewVerifier(issuer, keySet, &verifyConfig) token, err := verifier.Verify(context.Background(), idToken) require.NoError(t, err) return token } func RequireAuthCodeRegexpMatch( t *testing.T, actualContent string, wantRegexp string, kubeClient *fake.Clientset, secretsClient v1.SecretInterface, oauthStore fositestoragei.AllFositeStorage, wantDownstreamGrantedScopes []string, wantDownstreamIDTokenSubject string, wantDownstreamIDTokenUsername string, wantDownstreamIDTokenGroups []string, wantDownstreamRequestedScopes []string, wantDownstreamPKCEChallenge string, wantDownstreamPKCEChallengeMethod string, wantDownstreamNonce string, wantDownstreamClientID string, wantDownstreamRedirectURI string, wantCustomSessionData *psession.CustomSessionData, wantDownstreamAdditionalClaims map[string]interface{}, ) { t.Helper() // Assert that Location header matches regular expression. regex := regexp.MustCompile(wantRegexp) submatches := regex.FindStringSubmatch(actualContent) require.Lenf(t, submatches, 2, "no regexp match in actualContent: %", actualContent) capturedAuthCode := submatches[1] // Authcodes should start with the custom prefix "pin_ac_" to make them identifiable as authcodes when seen by a user out of context. require.True(t, strings.HasPrefix(capturedAuthCode, "pin_ac_"), "token %q did not have expected prefix 'pin_ac_'", capturedAuthCode) // fosite authcodes are in the format `data.signature`, so grab the signature part, which is the lookup key in the storage interface authcodeDataAndSignature := strings.Split(capturedAuthCode, ".") require.Len(t, authcodeDataAndSignature, 2) // Several Secrets should have been created expectedNumberOfCreatedSecrets := 2 if includesOpenIDScope(wantDownstreamGrantedScopes) { expectedNumberOfCreatedSecrets++ } require.Len(t, FilterClientSecretCreateActions(kubeClient.Actions()), expectedNumberOfCreatedSecrets) // One authcode should have been stored. testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secretsClient, labels.Set{crud.SecretLabelKey: authorizationcode.TypeLabelValue}, 1) storedRequestFromAuthcode, storedSessionFromAuthcode := validateAuthcodeStorage( t, oauthStore, authcodeDataAndSignature[1], // Authcode store key is authcode signature wantDownstreamGrantedScopes, wantDownstreamIDTokenSubject, wantDownstreamIDTokenUsername, wantDownstreamIDTokenGroups, wantDownstreamRequestedScopes, wantDownstreamClientID, wantDownstreamRedirectURI, wantCustomSessionData, wantDownstreamAdditionalClaims, ) // One PKCE should have been stored. testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secretsClient, labels.Set{crud.SecretLabelKey: pkce.TypeLabelValue}, 1) validatePKCEStorage( t, oauthStore, authcodeDataAndSignature[1], // PKCE store key is authcode signature storedRequestFromAuthcode, storedSessionFromAuthcode, wantDownstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod, ) // One IDSession should have been stored, if the downstream actually requested the "openid" scope if includesOpenIDScope(wantDownstreamGrantedScopes) { testutil.RequireNumberOfSecretsMatchingLabelSelector(t, secretsClient, labels.Set{crud.SecretLabelKey: openidconnect.TypeLabelValue}, 1) validateIDSessionStorage( t, oauthStore, capturedAuthCode, // IDSession store key is full authcode storedRequestFromAuthcode, storedSessionFromAuthcode, wantDownstreamNonce, ) } } func includesOpenIDScope(scopes []string) bool { for _, scope := range scopes { if scope == "openid" { return true } } return false } //nolint:funlen func validateAuthcodeStorage( t *testing.T, oauthStore fositestoragei.AllFositeStorage, storeKey string, wantDownstreamGrantedScopes []string, wantDownstreamIDTokenSubject string, wantDownstreamIDTokenUsername string, wantDownstreamIDTokenGroups []string, wantDownstreamRequestedScopes []string, wantDownstreamClientID string, wantDownstreamRedirectURI string, wantCustomSessionData *psession.CustomSessionData, wantDownstreamAdditionalClaims map[string]interface{}, ) (*fosite.Request, *psession.PinnipedSession) { t.Helper() const ( authCodeExpirationSeconds = 10 * 60 // Currently, we set our auth code expiration to 10 minutes timeComparisonFudgeFactor = time.Second * 15 ) // Get the authcode session back from storage so we can require that it was stored correctly. storedAuthorizeRequestFromAuthcode, err := oauthStore.GetAuthorizeCodeSession(context.Background(), storeKey, nil) require.NoError(t, err) // Check that storage returned the expected concrete data types. storedRequestFromAuthcode, storedSessionFromAuthcode := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromAuthcode) // Check which scopes were granted. require.ElementsMatch(t, wantDownstreamGrantedScopes, storedRequestFromAuthcode.GetGrantedScopes()) // Check all the other fields of the stored request. require.NotEmpty(t, storedRequestFromAuthcode.ID) require.Equal(t, wantDownstreamClientID, storedRequestFromAuthcode.Client.GetID()) require.ElementsMatch(t, wantDownstreamRequestedScopes, storedRequestFromAuthcode.RequestedScope) require.Nil(t, storedRequestFromAuthcode.RequestedAudience) require.Empty(t, storedRequestFromAuthcode.GrantedAudience) require.Equal(t, url.Values{"redirect_uri": []string{wantDownstreamRedirectURI}}, storedRequestFromAuthcode.Form) testutil.RequireTimeInDelta(t, time.Now(), storedRequestFromAuthcode.RequestedAt, timeComparisonFudgeFactor) // We're not using these fields yet, so confirm that we did not set them (for now). require.Empty(t, storedSessionFromAuthcode.Fosite.Subject) require.Empty(t, storedSessionFromAuthcode.Fosite.Username) require.Empty(t, storedSessionFromAuthcode.Fosite.Headers) // The authcode that we are issuing should be good for the length of time that we declare in the fosite config. testutil.RequireTimeInDelta(t, time.Now().Add(authCodeExpirationSeconds*time.Second), storedSessionFromAuthcode.Fosite.ExpiresAt[fosite.AuthorizeCode], timeComparisonFudgeFactor) require.Len(t, storedSessionFromAuthcode.Fosite.ExpiresAt, 1) // Now confirm the ID token claims. actualClaims := storedSessionFromAuthcode.Fosite.Claims // Should always have an azp claim. require.Equal(t, wantDownstreamClientID, actualClaims.Extra["azp"]) wantDownstreamIDTokenExtraClaimsCount := 1 // should always have azp claim if len(wantDownstreamAdditionalClaims) > 0 { wantDownstreamIDTokenExtraClaimsCount++ } // Check the user's identity, which are put into the downstream ID token's subject, username and groups claims. require.Equal(t, wantDownstreamIDTokenSubject, actualClaims.Subject) if wantDownstreamIDTokenUsername == "" { require.NotContains(t, actualClaims.Extra, "username") } else { wantDownstreamIDTokenExtraClaimsCount++ // should also have username claim require.Equal(t, wantDownstreamIDTokenUsername, actualClaims.Extra["username"]) } if slices.Contains(wantDownstreamGrantedScopes, "groups") { wantDownstreamIDTokenExtraClaimsCount++ // should also have groups claim actualDownstreamIDTokenGroups := actualClaims.Extra["groups"] require.NotNil(t, actualDownstreamIDTokenGroups) require.ElementsMatch(t, wantDownstreamIDTokenGroups, actualDownstreamIDTokenGroups) } else { require.Emptyf(t, wantDownstreamIDTokenGroups, "test case did not want the groups scope to be granted, "+ "but wanted something in the groups claim, which doesn't make sense. please review the test case's expectations.") actualDownstreamIDTokenGroups := actualClaims.Extra["groups"] require.Nil(t, actualDownstreamIDTokenGroups) } if len(wantDownstreamAdditionalClaims) > 0 { actualAdditionalClaims, ok := actualClaims.Get("additionalClaims").(map[string]interface{}) require.True(t, ok, "expected additionalClaims to be a map[string]interface{}") require.Equal(t, wantDownstreamAdditionalClaims, actualAdditionalClaims) } else { require.NotContains(t, actualClaims.Extra, "additionalClaims", "additionalClaims must not be present when there are no wanted additional claims") } // Make sure that we asserted on every extra claim. require.Len(t, actualClaims.Extra, wantDownstreamIDTokenExtraClaimsCount) // Check the rest of the downstream ID token's claims. Fosite wants us to set these (in UTC time). testutil.RequireTimeInDelta(t, time.Now().UTC(), actualClaims.RequestedAt, timeComparisonFudgeFactor) testutil.RequireTimeInDelta(t, time.Now().UTC(), actualClaims.AuthTime, timeComparisonFudgeFactor) requestedAtZone, _ := actualClaims.RequestedAt.Zone() require.Equal(t, "UTC", requestedAtZone) authTimeZone, _ := actualClaims.AuthTime.Zone() require.Equal(t, "UTC", authTimeZone) // Fosite will set these fields for us in the token endpoint based on the store session // information. Therefore, we assert that they are empty because we want the library to do the // lifting for us. require.Empty(t, actualClaims.Issuer) require.Nil(t, actualClaims.Audience) require.Empty(t, actualClaims.Nonce) require.Zero(t, actualClaims.ExpiresAt) require.Zero(t, actualClaims.IssuedAt) // These are not needed yet. require.Empty(t, actualClaims.JTI) require.Empty(t, actualClaims.CodeHash) require.Empty(t, actualClaims.AccessTokenHash) require.Empty(t, actualClaims.AuthenticationContextClassReference) require.Empty(t, actualClaims.AuthenticationMethodsReferences) // Check that the custom Pinniped session data matches. require.Equal(t, wantCustomSessionData, storedSessionFromAuthcode.Custom) return storedRequestFromAuthcode, storedSessionFromAuthcode } func validatePKCEStorage( t *testing.T, oauthStore fositestoragei.AllFositeStorage, storeKey string, storedRequestFromAuthcode *fosite.Request, storedSessionFromAuthcode *psession.PinnipedSession, wantDownstreamPKCEChallenge, wantDownstreamPKCEChallengeMethod string, ) { t.Helper() storedAuthorizeRequestFromPKCE, err := oauthStore.GetPKCERequestSession(context.Background(), storeKey, nil) require.NoError(t, err) // Check that storage returned the expected concrete data types. storedRequestFromPKCE, storedSessionFromPKCE := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromPKCE) // The stored PKCE request should be the same as the stored authcode request. require.Equal(t, storedRequestFromAuthcode.ID, storedRequestFromPKCE.ID) require.Equal(t, storedSessionFromAuthcode, storedSessionFromPKCE) // The stored PKCE request should also contain the PKCE challenge that the downstream sent us. require.Equal(t, wantDownstreamPKCEChallenge, storedRequestFromPKCE.Form.Get("code_challenge")) require.Equal(t, wantDownstreamPKCEChallengeMethod, storedRequestFromPKCE.Form.Get("code_challenge_method")) } func validateIDSessionStorage( t *testing.T, oauthStore fositestoragei.AllFositeStorage, storeKey string, storedRequestFromAuthcode *fosite.Request, storedSessionFromAuthcode *psession.PinnipedSession, wantDownstreamNonce string, ) { t.Helper() storedAuthorizeRequestFromIDSession, err := oauthStore.GetOpenIDConnectSession(context.Background(), storeKey, nil) require.NoError(t, err) // Check that storage returned the expected concrete data types. storedRequestFromIDSession, storedSessionFromIDSession := castStoredAuthorizeRequest(t, storedAuthorizeRequestFromIDSession) // The stored IDSession request should be the same as the stored authcode request. require.Equal(t, storedRequestFromAuthcode.ID, storedRequestFromIDSession.ID) require.Equal(t, storedSessionFromAuthcode, storedSessionFromIDSession) // The stored IDSession request should also contain the nonce that the downstream sent us. require.Equal(t, wantDownstreamNonce, storedRequestFromIDSession.Form.Get("nonce")) } func castStoredAuthorizeRequest(t *testing.T, storedAuthorizeRequest fosite.Requester) (*fosite.Request, *psession.PinnipedSession) { t.Helper() storedRequest, ok := storedAuthorizeRequest.(*fosite.Request) require.Truef(t, ok, "could not cast %T to %T", storedAuthorizeRequest, &fosite.Request{}) storedSession, ok := storedAuthorizeRequest.GetSession().(*psession.PinnipedSession) require.Truef(t, ok, "could not cast %T to %T", storedAuthorizeRequest.GetSession(), &psession.PinnipedSession{}) return storedRequest, storedSession } // FilterClientSecretCreateActions ignores any reads made to get a storage secret corresponding to an OIDCClient, since these // are normal actions when the request is using a dynamic client's client_id, and we don't need to make assertions // about these Secrets since they are not related to session storage. func FilterClientSecretCreateActions(actions []kubetesting.Action) []kubetesting.Action { filtered := make([]kubetesting.Action, 0, len(actions)) for _, action := range actions { if action.Matches("get", "secrets") { getAction := action.(kubetesting.GetAction) if strings.HasPrefix(getAction.GetName(), "pinniped-storage-oidc-client-secret-") { continue // filter out OIDCClient's storage secret reads } } filtered = append(filtered, action) // otherwise include the action } return filtered }