diff --git a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go index e3887b82..93085f4b 100644 --- a/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go +++ b/internal/mocks/mockupstreamoidcidentityprovider/mockupstreamoidcidentityprovider.go @@ -43,9 +43,9 @@ func (m *MockUpstreamOIDCIdentityProviderI) EXPECT() *MockUpstreamOIDCIdentityPr } // ExchangeAuthcodeAndValidateTokens mocks base method -func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(arg0 context.Context, arg1 string, arg2 pkce.Code, arg3 nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { +func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(arg0 context.Context, arg1 string, arg2 pkce.Code, arg3 nonce.Nonce, arg4 string) (oidctypes.Token, map[string]interface{}, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ExchangeAuthcodeAndValidateTokens", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "ExchangeAuthcodeAndValidateTokens", arg0, arg1, arg2, arg3, arg4) ret0, _ := ret[0].(oidctypes.Token) ret1, _ := ret[1].(map[string]interface{}) ret2, _ := ret[2].(error) @@ -53,9 +53,9 @@ func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(ar } // ExchangeAuthcodeAndValidateTokens indicates an expected call of ExchangeAuthcodeAndValidateTokens -func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ExchangeAuthcodeAndValidateTokens(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ExchangeAuthcodeAndValidateTokens(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExchangeAuthcodeAndValidateTokens", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ExchangeAuthcodeAndValidateTokens), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExchangeAuthcodeAndValidateTokens", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ExchangeAuthcodeAndValidateTokens), arg0, arg1, arg2, arg3, arg4) } // GetAuthorizationURL mocks base method diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index f237726b..4add765e 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -42,6 +42,7 @@ func NewHandler( idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provider, stateDecoder, cookieDecoder oidc.Decoder, + redirectURI string, ) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { state, err := validateRequest(r, stateDecoder, cookieDecoder) @@ -77,6 +78,7 @@ func NewHandler( authcode(r), state.PKCECode, state.Nonce, + redirectURI, ) if err != nil { plog.WarningErr("error exchanging and validating upstream tokens", err, "upstreamName", upstreamIDPConfig.GetName()) diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index e73eb759..ead11693 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -43,6 +43,8 @@ const ( happyUpstreamAuthcode = "upstream-auth-code" + happyUpstreamRedirectURI = "https://example.com/callback" + happyDownstreamState = "some-downstream-state-with-at-least-32-bytes" happyDownstreamCSRF = "test-csrf" happyDownstreamPKCE = "test-pkce" @@ -105,6 +107,7 @@ func TestCallbackEndpoint(t *testing.T) { Authcode: happyUpstreamAuthcode, PKCECodeVerifier: pkce.Code(happyDownstreamPKCE), ExpectedIDTokenNonce: nonce.Nonce(happyDownstreamNonce), + RedirectURI: happyUpstreamRedirectURI, } // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it @@ -433,7 +436,7 @@ func TestCallbackEndpoint(t *testing.T) { oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret) idpListGetter := oidctestutil.NewIDPListGetter(&test.idp) - subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) + subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec, happyUpstreamRedirectURI) req := httptest.NewRequest(test.method, test.path, nil) if test.csrfCookie != "" { req.Header.Set("Cookie", test.csrfCookie) diff --git a/internal/oidc/oidctestutil/oidc.go b/internal/oidc/oidctestutil/oidc.go index eafd567f..5b214e5c 100644 --- a/internal/oidc/oidctestutil/oidc.go +++ b/internal/oidc/oidctestutil/oidc.go @@ -24,6 +24,7 @@ type ExchangeAuthcodeAndValidateTokenArgs struct { Authcode string PKCECodeVerifier pkce.Code ExpectedIDTokenNonce nonce.Nonce + RedirectURI string } type TestUpstreamOIDCIdentityProvider struct { @@ -73,6 +74,7 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, + redirectURI string, ) (oidctypes.Token, map[string]interface{}, error) { if u.exchangeAuthcodeAndValidateTokensArgs == nil { u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) @@ -83,6 +85,7 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( Authcode: authcode, PKCECodeVerifier: pkceCodeVerifier, ExpectedIDTokenNonce: expectedIDTokenNonce, + RedirectURI: redirectURI, }) return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce) } diff --git a/internal/oidc/provider/dynamic_upstream_idp_provider.go b/internal/oidc/provider/dynamic_upstream_idp_provider.go index 8ef1e5db..be25ffe8 100644 --- a/internal/oidc/provider/dynamic_upstream_idp_provider.go +++ b/internal/oidc/provider/dynamic_upstream_idp_provider.go @@ -42,6 +42,7 @@ type UpstreamOIDCIdentityProviderI interface { authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, + redirectURI string, ) (tokens oidctypes.Token, parsedIDTokenClaims map[string]interface{}, err error) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index 42c828c1..6bac2c60 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -83,10 +83,25 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { encoder.SetSerializer(securecookie.JSONEncoder{}) authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath - m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder, encoder) + m.providerHandlers[authURL] = auth.NewHandler( + incomingProvider.Issuer(), + m.idpListGetter, + oauthHelper, + csrftoken.Generate, + pkce.Generate, + nonce.Generate, + encoder, + encoder, + ) callbackURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.CallbackEndpointPath - m.providerHandlers[callbackURL] = callback.NewHandler(m.idpListGetter, oauthHelper, encoder, encoder) + m.providerHandlers[callbackURL] = callback.NewHandler( + m.idpListGetter, + oauthHelper, + encoder, + encoder, + incomingProvider.Issuer()+oidc.CallbackEndpointPath, + ) plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) } diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index 4af7efdb..a789cb85 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -61,8 +61,13 @@ func (p *ProviderConfig) GetGroupsClaim() string { return p.GroupsClaim } -func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { - tok, err := p.Config.Exchange(oidc.ClientContext(ctx, p.Client), authcode, pkceCodeVerifier.Verifier()) +func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (oidctypes.Token, map[string]interface{}, error) { + tok, err := p.Config.Exchange( + oidc.ClientContext(ctx, p.Client), + authcode, + pkceCodeVerifier.Verifier(), + oauth2.SetAuthURLParam("redirect_uri", redirectURI), + ) if err != nil { return oidctypes.Token{}, nil, err } diff --git a/internal/upstreamoidc/upstreamoidc_test.go b/internal/upstreamoidc/upstreamoidc_test.go index 3f3eed2e..541d502f 100644 --- a/internal/upstreamoidc/upstreamoidc_test.go +++ b/internal/upstreamoidc/upstreamoidc_test.go @@ -181,7 +181,7 @@ func TestProviderConfig(t *testing.T) { ctx := context.Background() - tok, claims, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce) + tok, claims, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce, "https://example.com/callback") if tt.wantErr != "" { require.EqualError(t, err, tt.wantErr) require.Equal(t, oidctypes.Token{}, tok) diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index 2b21e080..0df34622 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -328,7 +328,14 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req // Exchange the authorization code for access, ID, and refresh tokens and perform required // validations on the returned ID token. - token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce) + token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient). + ExchangeAuthcodeAndValidateTokens( + r.Context(), + params.Get("code"), + h.pkce, + h.nonce, + h.oauth2Config.RedirectURL, + ) if err != nil { return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) } diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 374d90e3..96d790ba 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -488,6 +488,8 @@ func TestLogin(t *testing.T) { } func TestHandleAuthCodeCallback(t *testing.T) { + const testRedirectURI = "http://127.0.0.1:12324/callback" + tests := []struct { name string method string @@ -522,10 +524,11 @@ func TestHandleAuthCodeCallback(t *testing.T) { wantHTTPStatus: http.StatusBadRequest, opt: func(t *testing.T) Option { return func(h *handlerState) error { + h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). - ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error")) return mock } @@ -538,10 +541,11 @@ func TestHandleAuthCodeCallback(t *testing.T) { query: "state=test-state&code=valid", opt: func(t *testing.T) Option { return func(h *handlerState) error { + h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { mock := mockUpstream(t) mock.EXPECT(). - ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil) return mock }