From 8c3be3ffb2008b63f52f7a62d39921aaf9191c9e Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Fri, 4 Dec 2020 15:33:36 -0600 Subject: [PATCH] Refactor UpstreamOIDCIdentityProviderI claim handling. This refactors the `UpstreamOIDCIdentityProviderI` interface and its implementations to pass ID token claims through a `*oidctypes.Token` return parameter rather than as a third return parameter. Signed-off-by: Matt Moyer --- .../mockupstreamoidcidentityprovider.go | 18 +++++----- internal/oidc/callback/callback_handler.go | 6 ++-- .../oidc/callback/callback_handler_test.go | 7 ++-- internal/oidc/oidctestutil/oidc.go | 6 ++-- .../provider/dynamic_upstream_idp_provider.go | 4 +-- .../oidc/provider/manager/manager_test.go | 18 +++++----- internal/upstreamoidc/upstreamoidc.go | 21 ++++++------ internal/upstreamoidc/upstreamoidc_test.go | 33 +++++++++++-------- pkg/oidcclient/login.go | 10 ++---- pkg/oidcclient/login_test.go | 8 ++--- 10 files changed, 69 insertions(+), 62 deletions(-) diff --git a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go index 93085f4b..539f5727 100644 --- a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go +++ b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go @@ -43,13 +43,12 @@ func (m *MockUpstreamOIDCIdentityProviderI) EXPECT() *MockUpstreamOIDCIdentityPr } // ExchangeAuthcodeAndValidateTokens mocks base method -func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(arg0 context.Context, arg1 string, arg2 pkce.Code, arg3 nonce.Nonce, arg4 string) (oidctypes.Token, map[string]interface{}, error) { +func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(arg0 context.Context, arg1 string, arg2 pkce.Code, arg3 nonce.Nonce, arg4 string) (*oidctypes.Token, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ExchangeAuthcodeAndValidateTokens", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(oidctypes.Token) - ret1, _ := ret[1].(map[string]interface{}) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret0, _ := ret[0].(*oidctypes.Token) + ret1, _ := ret[1].(error) + return ret0, ret1 } // ExchangeAuthcodeAndValidateTokens indicates an expected call of ExchangeAuthcodeAndValidateTokens @@ -143,13 +142,12 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetUsernameClaim() *gom } // ValidateToken mocks base method -func (m *MockUpstreamOIDCIdentityProviderI) ValidateToken(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { +func (m *MockUpstreamOIDCIdentityProviderI) ValidateToken(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce) (*oidctypes.Token, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ValidateToken", arg0, arg1, arg2) - ret0, _ := ret[0].(oidctypes.Token) - ret1, _ := ret[1].(map[string]interface{}) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret0, _ := ret[0].(*oidctypes.Token) + ret1, _ := ret[1].(error) + return ret0, ret1 } // ValidateToken indicates an expected call of ValidateToken diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 4add765e..e102dc62 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -73,7 +73,7 @@ func NewHandler( // Grant the openid scope only if it was requested. grantOpenIDScopeIfRequested(authorizeRequester) - _, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens( + token, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens( r.Context(), authcode(r), state.PKCECode, @@ -85,12 +85,12 @@ func NewHandler( return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens") } - username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims) + username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims) if err != nil { return err } - groups, err := getGroupsFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims) + groups, err := getGroupsFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims) if err != nil { return err } diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index ead11693..d7e730fe 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -682,8 +682,11 @@ func (u *upstreamOIDCIdentityProviderBuilder) Build() oidctestutil.TestUpstreamO UsernameClaim: u.usernameClaim, GroupsClaim: u.groupsClaim, Scopes: []string{"scope1", "scope2"}, - ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { - return oidctypes.Token{}, u.idToken, u.authcodeExchangeErr + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { + if u.authcodeExchangeErr != nil { + return nil, u.authcodeExchangeErr + } + return &oidctypes.Token{IDToken: &oidctypes.IDToken{Claims: u.idToken}}, nil }, } } diff --git a/internal/oidc/oidctestutil/oidc.go b/internal/oidc/oidctestutil/oidc.go index 5b214e5c..bd293e65 100644 --- a/internal/oidc/oidctestutil/oidc.go +++ b/internal/oidc/oidctestutil/oidc.go @@ -39,7 +39,7 @@ type TestUpstreamOIDCIdentityProvider struct { authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, - ) (oidctypes.Token, map[string]interface{}, error) + ) (*oidctypes.Token, error) exchangeAuthcodeAndValidateTokensCallCount int exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs @@ -75,7 +75,7 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string, -) (oidctypes.Token, map[string]interface{}, error) { +) (*oidctypes.Token, error) { if u.exchangeAuthcodeAndValidateTokensArgs == nil { u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) } @@ -101,7 +101,7 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs return u.exchangeAuthcodeAndValidateTokensArgs[call] } -func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { +func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(_ context.Context, _ *oauth2.Token, _ nonce.Nonce) (*oidctypes.Token, error) { panic("implement me") } diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index be25ffe8..d5ab6f50 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -43,9 +43,9 @@ type UpstreamOIDCIdentityProviderI interface { pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string, - ) (tokens oidctypes.Token, parsedIDTokenClaims map[string]interface{}, err error) + ) (*oidctypes.Token, error) - ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) + ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) } type DynamicUpstreamIDPProvider interface { diff --git a/internal/oidc/provider/manager/manager_test.go b/internal/oidc/provider/manager/manager_test.go index a3f8090d..b4c7ecdf 100644 --- a/internal/oidc/provider/manager/manager_test.go +++ b/internal/oidc/provider/manager/manager_test.go @@ -172,15 +172,17 @@ func TestManager(t *testing.T) { ClientID: "test-client-id", AuthorizationURL: *parsedUpstreamIDPAuthorizationURL, Scopes: []string{"test-scope"}, - ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { - return oidctypes.Token{}, - map[string]interface{}{ - "iss": "https://some-issuer.com", - "sub": "some-subject", - "username": "test-username", - "groups": "test-group1", + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { + return &oidctypes.Token{ + IDToken: &oidctypes.IDToken{ + Claims: map[string]interface{}{ + "iss": "https://some-issuer.com", + "sub": "some-subject", + "username": "test-username", + "groups": "test-group1", + }, }, - nil + }, nil }, }) diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index a789cb85..af7c9a7c 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -61,7 +61,7 @@ func (p *ProviderConfig) GetGroupsClaim() string { return p.GroupsClaim } -func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (oidctypes.Token, map[string]interface{}, error) { +func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (*oidctypes.Token, error) { tok, err := p.Config.Exchange( oidc.ClientContext(ctx, p.Client), authcode, @@ -69,38 +69,38 @@ func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, oauth2.SetAuthURLParam("redirect_uri", redirectURI), ) if err != nil { - return oidctypes.Token{}, nil, err + return nil, err } return p.ValidateToken(ctx, tok, expectedIDTokenNonce) } -func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { +func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { idTok, hasIDTok := tok.Extra("id_token").(string) if !hasIDTok { - return oidctypes.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token") + return nil, httperr.New(http.StatusBadRequest, "received response missing ID token") } validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(oidc.ClientContext(ctx, p.Client), idTok) if err != nil { - return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) + return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) } if validated.AccessTokenHash != "" { if err := validated.VerifyAccessToken(tok.AccessToken); err != nil { - return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) + return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) } } if expectedIDTokenNonce != "" { if err := expectedIDTokenNonce.Validate(validated); err != nil { - return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err) + return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err) } } var validatedClaims map[string]interface{} if err := validated.Claims(&validatedClaims); err != nil { - return oidctypes.Token{}, nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal claims", err) + return nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal claims", err) } - return oidctypes.Token{ + return &oidctypes.Token{ AccessToken: &oidctypes.AccessToken{ Token: tok.AccessToken, Type: tok.TokenType, @@ -112,6 +112,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e IDToken: &oidctypes.IDToken{ Token: idTok, Expiry: metav1.NewTime(validated.Expiry), + Claims: validatedClaims, }, - }, validatedClaims, nil + }, nil } diff --git a/internal/upstreamoidc/upstreamoidc_test.go b/internal/upstreamoidc/upstreamoidc_test.go index 541d502f..946f618b 100644 --- a/internal/upstreamoidc/upstreamoidc_test.go +++ b/internal/upstreamoidc/upstreamoidc_test.go @@ -63,7 +63,6 @@ func TestProviderConfig(t *testing.T) { returnIDTok string wantErr string wantToken oidctypes.Token - wantClaims map[string]interface{} }{ { name: "exchange fails with network error", @@ -110,6 +109,14 @@ func TestProviderConfig(t *testing.T) { IDToken: &oidctypes.IDToken{ Token: invalidNonceIDToken, Expiry: metav1.Time{}, + Claims: map[string]interface{}{ + "aud": "test-client-id", + "iat": 1.602283741e+09, + "jti": "test-jti", + "nbf": 1.602283741e+09, + "nonce": "invalid-nonce", + "sub": "test-user", + }, }, }, }, @@ -128,12 +135,17 @@ func TestProviderConfig(t *testing.T) { IDToken: &oidctypes.IDToken{ Token: validIDToken, Expiry: metav1.Time{}, + Claims: map[string]interface{}{ + "foo": "bar", + "bat": "baz", + "aud": "test-client-id", + "iat": 1.606768593e+09, + "jti": "test-jti", + "nbf": 1.606768593e+09, + "sub": "test-user", + }, }, }, - wantClaims: map[string]interface{}{ - "foo": "bar", - "bat": "baz", - }, }, } for _, tt := range tests { @@ -181,19 +193,14 @@ func TestProviderConfig(t *testing.T) { ctx := context.Background() - tok, claims, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce, "https://example.com/callback") + tok, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce, "https://example.com/callback") if tt.wantErr != "" { require.EqualError(t, err, tt.wantErr) - require.Equal(t, oidctypes.Token{}, tok) - require.Nil(t, claims) + require.Nil(t, tok) return } require.NoError(t, err) - require.Equal(t, tt.wantToken, tok) - - for k, v := range tt.wantClaims { - require.Equal(t, v, claims[k]) - } + require.Equal(t, &tt.wantToken, tok) }) } } diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index 0df34622..a5f0e86d 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -295,11 +295,7 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // some providers do not include one, so we skip the nonce validation here (but not other validations). - token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "") - if err != nil { - return nil, err - } - return &token, nil + return h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "") } func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { @@ -328,7 +324,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req // Exchange the authorization code for access, ID, and refresh tokens and perform required // validations on the returned ID token. - token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient). + token, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient). ExchangeAuthcodeAndValidateTokens( r.Context(), params.Get("code"), @@ -340,7 +336,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) } - h.callbacks <- callbackResult{token: &token} + h.callbacks <- callbackResult{token: token} _, _ = w.Write([]byte("you have been logged in and may now close this tab")) return nil } diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 96d790ba..e86bc660 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -242,7 +242,7 @@ func TestLogin(t *testing.T) { mock := mockUpstream(t) mock.EXPECT(). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). - Return(testToken, nil, nil) + Return(&testToken, nil) return mock } @@ -281,7 +281,7 @@ func TestLogin(t *testing.T) { mock := mockUpstream(t) mock.EXPECT(). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). - Return(oidctypes.Token{}, nil, fmt.Errorf("some validation error")) + Return(nil, fmt.Errorf("some validation error")) return mock } @@ -529,7 +529,7 @@ func TestHandleAuthCodeCallback(t *testing.T) { mock := mockUpstream(t) mock.EXPECT(). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). - Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error")) + Return(nil, fmt.Errorf("some exchange error")) return mock } return nil @@ -546,7 +546,7 @@ func TestHandleAuthCodeCallback(t *testing.T) { mock := mockUpstream(t) mock.EXPECT(). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). - Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil) + Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil) return mock } return nil