diff --git a/cmd/pinniped/cmd/login_oidc.go b/cmd/pinniped/cmd/login_oidc.go index c8d00662..921f6e60 100644 --- a/cmd/pinniped/cmd/login_oidc.go +++ b/cmd/pinniped/cmd/login_oidc.go @@ -48,7 +48,7 @@ func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oid cmd.Flags().StringVar(&issuer, "issuer", "", "OpenID Connect issuer URL.") cmd.Flags().StringVar(&clientID, "client-id", "", "OpenID Connect client ID.") cmd.Flags().Uint16Var(&listenPort, "listen-port", 0, "TCP port for localhost listener (authorization code flow only).") - cmd.Flags().StringSliceVar(&scopes, "scopes", []string{"offline_access", "openid", "email", "profile"}, "OIDC scopes to request during login.") + cmd.Flags().StringSliceVar(&scopes, "scopes", []string{"offline_access", "openid"}, "OIDC scopes to request during login.") cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL).") cmd.Flags().StringVar(&sessionCachePath, "session-cache", filepath.Join(mustGetConfigDir(), "sessions.yaml"), "Path to session cache file.") cmd.Flags().StringSliceVar(&caBundlePaths, "ca-bundle", nil, "Path to TLS certificate authority bundle (PEM format, optional, can be repeated).") diff --git a/cmd/pinniped/cmd/login_oidc_test.go b/cmd/pinniped/cmd/login_oidc_test.go index 37cfac4e..64d0902e 100644 --- a/cmd/pinniped/cmd/login_oidc_test.go +++ b/cmd/pinniped/cmd/login_oidc_test.go @@ -46,7 +46,7 @@ func TestLoginOIDCCommand(t *testing.T) { -h, --help help for oidc --issuer string OpenID Connect issuer URL. --listen-port uint16 TCP port for localhost listener (authorization code flow only). - --scopes strings OIDC scopes to request during login. (default [offline_access,openid,email,profile]) + --scopes strings OIDC scopes to request during login. (default [offline_access,openid]) --session-cache string Path to session cache file. (default "` + cfgDir + `/sessions.yaml") --skip-browser Skip opening the browser (just print the URL). `), diff --git a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go index 93085f4b..539f5727 100644 --- a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go +++ b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go @@ -43,13 +43,12 @@ func (m *MockUpstreamOIDCIdentityProviderI) EXPECT() *MockUpstreamOIDCIdentityPr } // ExchangeAuthcodeAndValidateTokens mocks base method -func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(arg0 context.Context, arg1 string, arg2 pkce.Code, arg3 nonce.Nonce, arg4 string) (oidctypes.Token, map[string]interface{}, error) { +func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(arg0 context.Context, arg1 string, arg2 pkce.Code, arg3 nonce.Nonce, arg4 string) (*oidctypes.Token, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ExchangeAuthcodeAndValidateTokens", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].(oidctypes.Token) - ret1, _ := ret[1].(map[string]interface{}) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret0, _ := ret[0].(*oidctypes.Token) + ret1, _ := ret[1].(error) + return ret0, ret1 } // ExchangeAuthcodeAndValidateTokens indicates an expected call of ExchangeAuthcodeAndValidateTokens @@ -143,13 +142,12 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) GetUsernameClaim() *gom } // ValidateToken mocks base method -func (m *MockUpstreamOIDCIdentityProviderI) ValidateToken(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { +func (m *MockUpstreamOIDCIdentityProviderI) ValidateToken(arg0 context.Context, arg1 *oauth2.Token, arg2 nonce.Nonce) (*oidctypes.Token, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "ValidateToken", arg0, arg1, arg2) - ret0, _ := ret[0].(oidctypes.Token) - ret1, _ := ret[1].(map[string]interface{}) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret0, _ := ret[0].(*oidctypes.Token) + ret1, _ := ret[1].(error) + return ret0, ret1 } // ValidateToken indicates an expected call of ValidateToken diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 4add765e..e102dc62 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -73,7 +73,7 @@ func NewHandler( // Grant the openid scope only if it was requested. grantOpenIDScopeIfRequested(authorizeRequester) - _, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens( + token, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens( r.Context(), authcode(r), state.PKCECode, @@ -85,12 +85,12 @@ func NewHandler( return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens") } - username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims) + username, err := getUsernameFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims) if err != nil { return err } - groups, err := getGroupsFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims) + groups, err := getGroupsFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims) if err != nil { return err } diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 5a0e8c6c..ed4c6b01 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -676,8 +676,11 @@ func (u *upstreamOIDCIdentityProviderBuilder) Build() oidctestutil.TestUpstreamO UsernameClaim: u.usernameClaim, GroupsClaim: u.groupsClaim, Scopes: []string{"scope1", "scope2"}, - ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier oidcpkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { - return oidctypes.Token{}, u.idToken, u.authcodeExchangeErr + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier oidcpkce.Code, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { + if u.authcodeExchangeErr != nil { + return nil, u.authcodeExchangeErr + } + return &oidctypes.Token{IDToken: &oidctypes.IDToken{Claims: u.idToken}}, nil }, } } diff --git a/internal/oidc/oidctestutil/oidc.go b/internal/oidc/oidctestutil/oidc.go index 19aa7c07..09af7236 100644 --- a/internal/oidc/oidctestutil/oidc.go +++ b/internal/oidc/oidctestutil/oidc.go @@ -46,7 +46,7 @@ type TestUpstreamOIDCIdentityProvider struct { authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, - ) (oidctypes.Token, map[string]interface{}, error) + ) (*oidctypes.Token, error) exchangeAuthcodeAndValidateTokensCallCount int exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs @@ -82,7 +82,7 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string, -) (oidctypes.Token, map[string]interface{}, error) { +) (*oidctypes.Token, error) { if u.exchangeAuthcodeAndValidateTokensArgs == nil { u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) } @@ -108,7 +108,7 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs return u.exchangeAuthcodeAndValidateTokensArgs[call] } -func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { +func (u *TestUpstreamOIDCIdentityProvider) ValidateToken(_ context.Context, _ *oauth2.Token, _ nonce.Nonce) (*oidctypes.Token, error) { panic("implement me") } diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index be25ffe8..d5ab6f50 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -43,9 +43,9 @@ type UpstreamOIDCIdentityProviderI interface { pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string, - ) (tokens oidctypes.Token, parsedIDTokenClaims map[string]interface{}, err error) + ) (*oidctypes.Token, error) - ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) + ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) } type DynamicUpstreamIDPProvider interface { diff --git a/internal/oidc/provider/manager/manager_test.go b/internal/oidc/provider/manager/manager_test.go index 2cf9276d..ed048a31 100644 --- a/internal/oidc/provider/manager/manager_test.go +++ b/internal/oidc/provider/manager/manager_test.go @@ -227,15 +227,17 @@ func TestManager(t *testing.T) { ClientID: "test-client-id", AuthorizationURL: *parsedUpstreamIDPAuthorizationURL, Scopes: []string{"test-scope"}, - ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { - return oidctypes.Token{}, - map[string]interface{}{ - "iss": "https://some-issuer.com", - "sub": "some-subject", - "username": "test-username", - "groups": "test-group1", + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { + return &oidctypes.Token{ + IDToken: &oidctypes.IDToken{ + Claims: map[string]interface{}{ + "iss": "https://some-issuer.com", + "sub": "some-subject", + "username": "test-username", + "groups": "test-group1", + }, }, - nil + }, nil }, }) diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index a789cb85..af7c9a7c 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -61,7 +61,7 @@ func (p *ProviderConfig) GetGroupsClaim() string { return p.GroupsClaim } -func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (oidctypes.Token, map[string]interface{}, error) { +func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (*oidctypes.Token, error) { tok, err := p.Config.Exchange( oidc.ClientContext(ctx, p.Client), authcode, @@ -69,38 +69,38 @@ func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, oauth2.SetAuthURLParam("redirect_uri", redirectURI), ) if err != nil { - return oidctypes.Token{}, nil, err + return nil, err } return p.ValidateToken(ctx, tok, expectedIDTokenNonce) } -func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { +func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { idTok, hasIDTok := tok.Extra("id_token").(string) if !hasIDTok { - return oidctypes.Token{}, nil, httperr.New(http.StatusBadRequest, "received response missing ID token") + return nil, httperr.New(http.StatusBadRequest, "received response missing ID token") } validated, err := p.Provider.Verifier(&oidc.Config{ClientID: p.GetClientID()}).Verify(oidc.ClientContext(ctx, p.Client), idTok) if err != nil { - return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) + return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) } if validated.AccessTokenHash != "" { if err := validated.VerifyAccessToken(tok.AccessToken); err != nil { - return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) + return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) } } if expectedIDTokenNonce != "" { if err := expectedIDTokenNonce.Validate(validated); err != nil { - return oidctypes.Token{}, nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err) + return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err) } } var validatedClaims map[string]interface{} if err := validated.Claims(&validatedClaims); err != nil { - return oidctypes.Token{}, nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal claims", err) + return nil, httperr.Wrap(http.StatusInternalServerError, "could not unmarshal claims", err) } - return oidctypes.Token{ + return &oidctypes.Token{ AccessToken: &oidctypes.AccessToken{ Token: tok.AccessToken, Type: tok.TokenType, @@ -112,6 +112,7 @@ func (p *ProviderConfig) ValidateToken(ctx context.Context, tok *oauth2.Token, e IDToken: &oidctypes.IDToken{ Token: idTok, Expiry: metav1.NewTime(validated.Expiry), + Claims: validatedClaims, }, - }, validatedClaims, nil + }, nil } diff --git a/internal/upstreamoidc/upstreamoidc_test.go b/internal/upstreamoidc/upstreamoidc_test.go index 541d502f..946f618b 100644 --- a/internal/upstreamoidc/upstreamoidc_test.go +++ b/internal/upstreamoidc/upstreamoidc_test.go @@ -63,7 +63,6 @@ func TestProviderConfig(t *testing.T) { returnIDTok string wantErr string wantToken oidctypes.Token - wantClaims map[string]interface{} }{ { name: "exchange fails with network error", @@ -110,6 +109,14 @@ func TestProviderConfig(t *testing.T) { IDToken: &oidctypes.IDToken{ Token: invalidNonceIDToken, Expiry: metav1.Time{}, + Claims: map[string]interface{}{ + "aud": "test-client-id", + "iat": 1.602283741e+09, + "jti": "test-jti", + "nbf": 1.602283741e+09, + "nonce": "invalid-nonce", + "sub": "test-user", + }, }, }, }, @@ -128,12 +135,17 @@ func TestProviderConfig(t *testing.T) { IDToken: &oidctypes.IDToken{ Token: validIDToken, Expiry: metav1.Time{}, + Claims: map[string]interface{}{ + "foo": "bar", + "bat": "baz", + "aud": "test-client-id", + "iat": 1.606768593e+09, + "jti": "test-jti", + "nbf": 1.606768593e+09, + "sub": "test-user", + }, }, }, - wantClaims: map[string]interface{}{ - "foo": "bar", - "bat": "baz", - }, }, } for _, tt := range tests { @@ -181,19 +193,14 @@ func TestProviderConfig(t *testing.T) { ctx := context.Background() - tok, claims, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce, "https://example.com/callback") + tok, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce, "https://example.com/callback") if tt.wantErr != "" { require.EqualError(t, err, tt.wantErr) - require.Equal(t, oidctypes.Token{}, tok) - require.Nil(t, claims) + require.Nil(t, tok) return } require.NoError(t, err) - require.Equal(t, tt.wantToken, tok) - - for k, v := range tt.wantClaims { - require.Equal(t, v, claims[k]) - } + require.Equal(t, &tt.wantToken, tok) }) } } diff --git a/pkg/oidcclient/filesession/cachefile_test.go b/pkg/oidcclient/filesession/cachefile_test.go index b1e1c984..39ac87fb 100644 --- a/pkg/oidcclient/filesession/cachefile_test.go +++ b/pkg/oidcclient/filesession/cachefile_test.go @@ -38,6 +38,13 @@ var validSession = sessionCache{ IDToken: &oidctypes.IDToken{ Token: "test-id-token", Expiry: metav1.NewTime(time.Date(2020, 10, 20, 19, 42, 07, 0, time.UTC).Local()), + Claims: map[string]interface{}{ + "foo": "bar", + "nested": map[string]interface{}{ + "key1": "value1", + "key2": "value2", + }, + }, }, RefreshToken: &oidctypes.RefreshToken{ Token: "test-refresh-token", diff --git a/pkg/oidcclient/filesession/testdata/valid.yaml b/pkg/oidcclient/filesession/testdata/valid.yaml index 67602c7b..e0d1943c 100644 --- a/pkg/oidcclient/filesession/testdata/valid.yaml +++ b/pkg/oidcclient/filesession/testdata/valid.yaml @@ -20,5 +20,10 @@ sessions: id: expiryTimestamp: "2020-10-20T19:42:07Z" token: test-id-token + claims: + foo: bar + nested: + key1: value1 + key2: value2 refresh: token: test-refresh-token diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index 0df34622..a5f0e86d 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -295,11 +295,7 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least // some providers do not include one, so we skip the nonce validation here (but not other validations). - token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "") - if err != nil { - return nil, err - } - return &token, nil + return h.getProvider(h.oauth2Config, h.provider, h.httpClient).ValidateToken(ctx, refreshed, "") } func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { @@ -328,7 +324,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req // Exchange the authorization code for access, ID, and refresh tokens and perform required // validations on the returned ID token. - token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient). + token, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient). ExchangeAuthcodeAndValidateTokens( r.Context(), params.Get("code"), @@ -340,7 +336,7 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) } - h.callbacks <- callbackResult{token: &token} + h.callbacks <- callbackResult{token: token} _, _ = w.Write([]byte("you have been logged in and may now close this tab")) return nil } diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 96d790ba..e86bc660 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -242,7 +242,7 @@ func TestLogin(t *testing.T) { mock := mockUpstream(t) mock.EXPECT(). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). - Return(testToken, nil, nil) + Return(&testToken, nil) return mock } @@ -281,7 +281,7 @@ func TestLogin(t *testing.T) { mock := mockUpstream(t) mock.EXPECT(). ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")). - Return(oidctypes.Token{}, nil, fmt.Errorf("some validation error")) + Return(nil, fmt.Errorf("some validation error")) return mock } @@ -529,7 +529,7 @@ func TestHandleAuthCodeCallback(t *testing.T) { mock := mockUpstream(t) mock.EXPECT(). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). - Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error")) + Return(nil, fmt.Errorf("some exchange error")) return mock } return nil @@ -546,7 +546,7 @@ func TestHandleAuthCodeCallback(t *testing.T) { mock := mockUpstream(t) mock.EXPECT(). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). - Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil) + Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil) return mock } return nil diff --git a/pkg/oidcclient/oidctypes/oidctypes.go b/pkg/oidcclient/oidctypes/oidctypes.go index 94f5dcc9..d3d1b658 100644 --- a/pkg/oidcclient/oidctypes/oidctypes.go +++ b/pkg/oidcclient/oidctypes/oidctypes.go @@ -31,6 +31,9 @@ type IDToken struct { // Expiry is the optional expiration time of the ID token. Expiry v1.Time `json:"expiryTimestamp,omitempty"` + + // Claims are the claims expressed by the Token. + Claims map[string]interface{} `json:"claims,omitempty"` } // Token contains the elements of an OIDC session. diff --git a/test/integration/cli_test.go b/test/integration/cli_test.go index c4c051eb..cc51904e 100644 --- a/test/integration/cli_test.go +++ b/test/integration/cli_test.go @@ -313,6 +313,7 @@ func oidcLoginCommand(ctx context.Context, t *testing.T, pinnipedExe string, ses cmd := exec.CommandContext(ctx, pinnipedExe, "login", "oidc", "--issuer", env.CLITestUpstream.Issuer, "--client-id", env.CLITestUpstream.ClientID, + "--scopes", "offline_access,openid,email,profile", "--listen-port", callbackURL.Port(), "--session-cache", sessionCachePath, "--skip-browser",