diff --git a/internal/oidc/auth/auth_handler.go b/internal/oidc/auth/auth_handler.go index 36e42b25..5bfe7947 100644 --- a/internal/oidc/auth/auth_handler.go +++ b/internal/oidc/auth/auth_handler.go @@ -121,13 +121,22 @@ func NewHandler( } } + authCodeOptions := []oauth2.AuthCodeOption{ + oauth2.AccessTypeOffline, + nonceValue.Param(), + pkceValue.Challenge(), + pkceValue.Method(), + } + + promptParam := r.Form.Get("prompt") + if promptParam != "" && oidc.ScopeWasRequested(authorizeRequester, coreosoidc.ScopeOpenID) { + authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("prompt", promptParam)) + } + http.Redirect(w, r, upstreamOAuthConfig.AuthCodeURL( encodedStateParamValue, - oauth2.AccessTypeOffline, - nonceValue.Param(), - pkceValue.Challenge(), - pkceValue.Method(), + authCodeOptions..., ), 302, ) diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index d9dc051c..b937a228 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -57,8 +57,8 @@ func TestAuthorizationEndpoint(t *testing.T) { fositePromptHasNoneAndOtherValueErrorQuery = map[string]string{ "error": "invalid_request", - "error_description": "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed\n\nUsed unknown value \"[none login]\" for prompt parameter", - "error_hint": "Used unknown value \"[none login]\" for prompt parameter", + "error_description": "The request is missing a required parameter, includes an invalid parameter value, includes a parameter more than once, or is otherwise malformed\n\nParameter \"prompt\" was set to \"none\", but contains other values as well which is not allowed.", + "error_hint": "Parameter \"prompt\" was set to \"none\", but contains other values as well which is not allowed.", "state": "some-state-value-that-is-32-byte", } @@ -99,8 +99,8 @@ func TestAuthorizationEndpoint(t *testing.T) { fositeInvalidStateErrorQuery = map[string]string{ "error": "invalid_state", - "error_description": "The state is missing or does not have enough characters and is therefore considered too weak\n\nRequest parameter \"state\" must be at least be 32 characters long to ensure sufficient entropy.", - "error_hint": `Request parameter "state" must be at least be 32 characters long to ensure sufficient entropy.`, + "error_description": "The state is missing or does not have enough characters and is therefore considered too weak\n\nRequest parameter \"state\" must be at least be 8 characters long to ensure sufficient entropy.", + "error_hint": `Request parameter "state" must be at least be 8 characters long to ensure sufficient entropy.`, "state": "short", } @@ -229,8 +229,8 @@ func TestAuthorizationEndpoint(t *testing.T) { return encoded } - expectedRedirectLocation := func(expectedUpstreamState string) string { - return urlWithQuery(upstreamAuthURL.String(), map[string]string{ + expectedRedirectLocation := func(expectedUpstreamState string, expectedPrompt string) string { + query := map[string]string{ "response_type": "code", "access_type": "offline", "scope": "scope1 scope2", @@ -240,7 +240,11 @@ func TestAuthorizationEndpoint(t *testing.T) { "code_challenge": expectedUpstreamCodeChallenge, "code_challenge_method": "S256", "redirect_uri": downstreamIssuer + "/callback", - }) + } + if expectedPrompt != "" { + query["prompt"] = expectedPrompt + } + return urlWithQuery(upstreamAuthURL.String(), query) } incomingCookieCSRFValue := "csrf-value-from-cookie" @@ -288,7 +292,7 @@ func TestAuthorizationEndpoint(t *testing.T) { wantStatus: http.StatusFound, wantContentType: "text/html; charset=utf-8", wantCSRFValueInCookieHeader: happyCSRF, - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")), + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", ""), ""), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, @@ -306,7 +310,7 @@ func TestAuthorizationEndpoint(t *testing.T) { csrfCookie: "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue + " ", wantStatus: http.StatusFound, wantContentType: "text/html; charset=utf-8", - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue, "")), + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, incomingCookieCSRFValue, ""), ""), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, @@ -327,7 +331,27 @@ func TestAuthorizationEndpoint(t *testing.T) { wantContentType: "", wantBodyString: "", wantCSRFValueInCookieHeader: happyCSRF, - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")), + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", ""), ""), + wantUpstreamStateParamInLocationHeader: true, + }, + { + name: "happy path with prompt param login passed through to redirect uri", + issuer: downstreamIssuer, + idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), + generateCSRF: happyCSRFGenerator, + generatePKCE: happyPKCEGenerator, + generateNonce: happyNonceGenerator, + stateEncoder: happyStateEncoder, + cookieEncoder: happyCookieEncoder, + method: http.MethodGet, + path: modifiedHappyGetRequestPath(map[string]string{"prompt": "login"}), + contentType: "application/x-www-form-urlencoded", + body: encodeQuery(happyGetRequestQueryMap), + wantStatus: http.StatusFound, + wantContentType: "text/html; charset=utf-8", + wantBodyStringWithLocationInHref: true, + wantCSRFValueInCookieHeader: happyCSRF, + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{"prompt": "login"}, "", ""), "login"), wantUpstreamStateParamInLocationHeader: true, }, { @@ -346,7 +370,7 @@ func TestAuthorizationEndpoint(t *testing.T) { wantContentType: "text/html; charset=utf-8", // Generated a new CSRF cookie and set it in the response. wantCSRFValueInCookieHeader: happyCSRF, - wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", "")), + wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil, "", ""), ""), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, @@ -368,7 +392,7 @@ func TestAuthorizationEndpoint(t *testing.T) { wantCSRFValueInCookieHeader: happyCSRF, wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{ "redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client - }, "", "")), + }, "", ""), ""), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, @@ -388,7 +412,7 @@ func TestAuthorizationEndpoint(t *testing.T) { wantCSRFValueInCookieHeader: happyCSRF, wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{ "scope": "openid offline_access", - }, "", "")), + }, "", ""), ""), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, @@ -586,7 +610,7 @@ func TestAuthorizationEndpoint(t *testing.T) { wantCSRFValueInCookieHeader: happyCSRF, wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam( map[string]string{"prompt": "none login", "scope": "email"}, "", "", - )), + ), ""), wantUpstreamStateParamInLocationHeader: true, wantBodyStringWithLocationInHref: true, }, diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 69ee1dfa..a1cb56fe 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -259,9 +259,16 @@ type IDPListGetter interface { } func GrantScopeIfRequested(authorizeRequester fosite.AuthorizeRequester, scopeName string) { - for _, scope := range authorizeRequester.GetRequestedScopes() { - if scope == scopeName { - authorizeRequester.GrantScope(scope) - } + if ScopeWasRequested(authorizeRequester, scopeName) { + authorizeRequester.GrantScope(scopeName) } } + +func ScopeWasRequested(authorizeRequester fosite.AuthorizeRequester, scopeName string) bool { + for _, scope := range authorizeRequester.GetRequestedScopes() { + if scope == scopeName { + return true + } + } + return false +}