diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index 08c42ada..8a566620 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -92,11 +92,18 @@ func NewHandler( Endpoint: oauth2.Endpoint{ AuthURL: upstreamIDP.GetAuthorizationURL().String(), }, - RedirectURL: fmt.Sprintf("%s/callback/%s", downstreamIssuer, upstreamIDP.GetName()), + RedirectURL: fmt.Sprintf("%s/callback", downstreamIssuer), Scopes: upstreamIDP.GetScopes(), } - encodedStateParamValue, err := upstreamStateParam(authorizeRequester, nonceValue, csrfValue, pkceValue, upstreamStateEncoder) + encodedStateParamValue, err := upstreamStateParam( + authorizeRequester, + upstreamIDP.GetName(), + nonceValue, + csrfValue, + pkceValue, + upstreamStateEncoder, + ) if err != nil { plog.Error("authorize upstream state param error", err) return err @@ -188,6 +195,7 @@ func generateValues( func upstreamStateParam( authorizeRequester fosite.AuthorizeRequester, + upstreamName string, nonceValue nonce.Nonce, csrfValue csrftoken.CSRFToken, pkceValue pkce.Code, @@ -195,6 +203,7 @@ func upstreamStateParam( ) (string, error) { stateParamData := oidc.UpstreamStateParamData{ AuthParams: authorizeRequester.GetRequestForm().Encode(), + UpstreamName: upstreamName, Nonce: nonceValue, CSRFToken: csrfValue, PKCECode: pkceValue, diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 4003f9c2..381ff052 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -204,14 +204,19 @@ func TestAuthorizationEndpoint(t *testing.T) { return pathWithQuery("/some/path", modifiedHappyGetRequestQueryMap(queryOverrides)) } - expectedUpstreamStateParam := func(queryOverrides map[string]string, csrfValueOverride string) string { + expectedUpstreamStateParam := func(queryOverrides map[string]string, csrfValueOverride, upstreamNameOverride string) string { csrf := happyCSRF if csrfValueOverride != "" { csrf = csrfValueOverride } + upstreamName := upstreamOIDCIdentityProvider.Name + if upstreamNameOverride != "" { + upstreamName = upstreamNameOverride + } encoded, err := happyStateEncoder.Encode("s", oidctestutil.ExpectedUpstreamStateParamFormat{ P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)), + U: upstreamName, N: happyNonce, C: csrf, K: happyPKCE, @@ -232,7 +237,7 @@ func TestAuthorizationEndpoint(t *testing.T) { "nonce": happyNonce, "code_challenge": expectedUpstreamCodeChallenge, "code_challenge_method": "S256", - "redirect_uri": downstreamIssuer + "/callback/some-idp", + "redirect_uri": downstreamIssuer + "/callback", }) } @@ -281,7 +286,7 @@ func TestAuthorizationEndpoint(t *testing.T) { wantStatus: http.StatusFound, wantContentType: "text/html; charset=utf-8", wantCSRFValueInCookieHeader: happyCSRF, - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "")), + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, @@ -299,7 +304,7 @@ func TestAuthorizationEndpoint(t *testing.T) { csrfCookie: "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue, wantStatus: http.StatusFound, wantContentType: "text/html; charset=utf-8", - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue)), + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue, "")), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, @@ -320,7 +325,7 @@ func TestAuthorizationEndpoint(t *testing.T) { wantContentType: "", wantBodyString: "", wantCSRFValueInCookieHeader: happyCSRF, - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "")), + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")), wantUpstreamStateParamInLocationHeader: true, }, { @@ -341,7 +346,7 @@ func TestAuthorizationEndpoint(t *testing.T) { wantCSRFValueInCookieHeader: happyCSRF, wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{ "redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client - }, "")), + }, "", "")), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, @@ -538,7 +543,7 @@ func TestAuthorizationEndpoint(t *testing.T) { wantContentType: "text/html; charset=utf-8", wantCSRFValueInCookieHeader: happyCSRF, wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam( - map[string]string{"prompt": "none login", "scope": "email"}, "", + map[string]string{"prompt": "none login", "scope": "email"}, "", "", )), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, @@ -787,11 +792,11 @@ func TestAuthorizationEndpoint(t *testing.T) { "access_type": "offline", "scope": "other-scope1 other-scope2", "client_id": "some-other-client-id", - "state": expectedUpstreamStateParam(nil, ""), + "state": expectedUpstreamStateParam(nil, "", newProviderSettings.Name), "nonce": happyNonce, "code_challenge": expectedUpstreamCodeChallenge, "code_challenge_method": "S256", - "redirect_uri": downstreamIssuer + "/callback/some-other-idp", + "redirect_uri": downstreamIssuer + "/callback", }, ) test.wantBodyString = fmt.Sprintf(`Found.%s`, diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 017b9249..ee69ee13 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -66,6 +66,7 @@ type Codec interface { // the state param. type UpstreamStateParamData struct { AuthParams string `json:"p"` + UpstreamName string `json:"u"` Nonce nonce.Nonce `json:"n"` CSRFToken csrftoken.CSRFToken `json:"c"` PKCECode pkce.Code `json:"k"` diff --git a/internal/oidc/oidctestutil/oidc.go b/internal/oidc/oidctestutil/oidc.go index ad4338db..fc8c3092 100644 --- a/internal/oidc/oidctestutil/oidc.go +++ b/internal/oidc/oidctestutil/oidc.go @@ -112,6 +112,7 @@ func NewIDPListGetter(upstreamOIDCIdentityProviders ...*TestUpstreamOIDCIdentity // assertions about the redirect URL in this test. type ExpectedUpstreamStateParamFormat struct { P string `json:"p"` + U string `json:"u"` N string `json:"n"` C string `json:"c"` K string `json:"k"`