diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index fdf34eff..f354b94a 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -87,6 +87,7 @@ type handlerState struct { // Generated parameters of a login flow. provider *oidc.Provider oauth2Config *oauth2.Config + useFormPost bool state state.State nonce nonce.Nonce pkce pkce.Code @@ -486,8 +487,14 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp Path: h.callbackPath, }).String() + // If the server supports it, request response_mode=form_post. + authParams := *authorizeOptions + if h.useFormPost { + authParams = append(authParams, oauth2.SetAuthURLParam("response_mode", "form_post")) + } + // Now that we have a redirect URL with the listener port, we can build the authorize URL. - authorizeURL := h.oauth2Config.AuthCodeURL(h.state.String(), *authorizeOptions...) + authorizeURL := h.oauth2Config.AuthCodeURL(h.state.String(), authParams...) // Start a callback server in a background goroutine. shutdown := h.serve(listener) @@ -567,9 +574,27 @@ func (h *handlerState) initOIDCDiscovery() error { Endpoint: h.provider.Endpoint(), Scopes: h.scopes, } + + // Use response_mode=form_post if the provider supports it. + var discoveryClaims struct { + ResponseModesSupported []string `json:"response_modes_supported"` + } + if err := h.provider.Claims(&discoveryClaims); err != nil { + return fmt.Errorf("could not decode response_modes_supported in OIDC discovery from %q: %w", h.issuer, err) + } + h.useFormPost = stringSliceContains(discoveryClaims.ResponseModesSupported, "form_post") return nil } +func stringSliceContains(slice []string, s string) bool { + for _, item := range slice { + if item == s { + return true + } + } + return false +} + func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidctypes.Token, error) { h.logger.V(debugLogLevel).Info("Pinniped: Performing RFC8693 token exchange", "requestedAudience", h.requestedAudience) // Perform OIDC discovery. This may have already been performed if there was not a cached base token. diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index d85c01ac..31d84def 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -80,6 +80,22 @@ func TestLogin(t *testing.T) { // nolint:gocyclo })) t.Cleanup(errorServer.Close) + // Start a test server that returns discovery data with a broken response_modes_supported value. + brokenResponseModeMux := http.NewServeMux() + brokenResponseModeServer := httptest.NewServer(brokenResponseModeMux) + brokenResponseModeMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("content-type", "application/json") + type providerJSON struct { + Issuer string `json:"issuer"` + ResponseModesSupported string `json:"response_modes_supported"` // Wrong type (should be []string). + } + _ = json.NewEncoder(w).Encode(&providerJSON{ + Issuer: brokenResponseModeServer.URL, + ResponseModesSupported: "invalid", + }) + }) + t.Cleanup(brokenResponseModeServer.Close) + // Start a test server that returns discovery data with a broken token URL brokenTokenURLMux := http.NewServeMux() brokenTokenURLServer := httptest.NewServer(brokenTokenURLMux) @@ -100,30 +116,29 @@ func TestLogin(t *testing.T) { // nolint:gocyclo }) t.Cleanup(brokenTokenURLServer.Close) - // Start a test server that returns a real discovery document and answers refresh requests. - providerMux := http.NewServeMux() - successServer := httptest.NewServer(providerMux) - t.Cleanup(successServer.Close) - providerMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodGet { - http.Error(w, "unexpected method", http.StatusMethodNotAllowed) - return + discoveryHandler := func(server *httptest.Server, responseModes []string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "unexpected method", http.StatusMethodNotAllowed) + return + } + w.Header().Set("content-type", "application/json") + _ = json.NewEncoder(w).Encode(&struct { + Issuer string `json:"issuer"` + AuthURL string `json:"authorization_endpoint"` + TokenURL string `json:"token_endpoint"` + JWKSURL string `json:"jwks_uri"` + ResponseModesSupported []string `json:"response_modes_supported,omitempty"` + }{ + Issuer: server.URL, + AuthURL: server.URL + "/authorize", + TokenURL: server.URL + "/token", + JWKSURL: server.URL + "/keys", + ResponseModesSupported: responseModes, + }) } - w.Header().Set("content-type", "application/json") - type providerJSON struct { - Issuer string `json:"issuer"` - AuthURL string `json:"authorization_endpoint"` - TokenURL string `json:"token_endpoint"` - JWKSURL string `json:"jwks_uri"` - } - _ = json.NewEncoder(w).Encode(&providerJSON{ - Issuer: successServer.URL, - AuthURL: successServer.URL + "/authorize", - TokenURL: successServer.URL + "/token", - JWKSURL: successServer.URL + "/keys", - }) - }) - providerMux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + } + tokenHandler := func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "unexpected method", http.StatusMethodNotAllowed) return @@ -204,7 +219,21 @@ func TestLogin(t *testing.T) { // nolint:gocyclo w.Header().Set("content-type", "application/json") require.NoError(t, json.NewEncoder(w).Encode(&response)) - }) + } + + // Start a test server that returns a real discovery document and answers refresh requests. + providerMux := http.NewServeMux() + successServer := httptest.NewServer(providerMux) + t.Cleanup(successServer.Close) + providerMux.HandleFunc("/.well-known/openid-configuration", discoveryHandler(successServer, nil)) + providerMux.HandleFunc("/token", tokenHandler) + + // Start a test server that returns a real discovery document and answers refresh requests, _and_ supports form_mode=post. + formPostProviderMux := http.NewServeMux() + formPostSuccessServer := httptest.NewServer(formPostProviderMux) + t.Cleanup(formPostSuccessServer.Close) + formPostProviderMux.HandleFunc("/.well-known/openid-configuration", discoveryHandler(formPostSuccessServer, []string{"query", "form_post"})) + formPostProviderMux.HandleFunc("/token", tokenHandler) defaultDiscoveryResponse := func(req *http.Request) (*http.Response, error) { // nolint:unparam // Call the handler function from the test server to calculate the response. @@ -349,7 +378,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo wantToken: &testToken, }, { - name: "discovery failure", + name: "discovery failure due to 500 error", opt: func(t *testing.T) Option { return func(h *handlerState) error { return nil } }, @@ -357,6 +386,15 @@ func TestLogin(t *testing.T) { // nolint:gocyclo wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + errorServer.URL + "\""}, wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL), }, + { + name: "discovery failure due to invalid response_modes_supported", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { return nil } + }, + issuer: brokenResponseModeServer.URL, + wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + brokenResponseModeServer.URL + "\""}, + wantErr: fmt.Sprintf("could not decode response_modes_supported in OIDC discovery from %q: json: cannot unmarshal string into Go struct field .response_modes_supported of type []string", brokenResponseModeServer.URL), + }, { name: "session cache hit with refreshable token", issuer: successServer.URL, @@ -580,6 +618,68 @@ func TestLogin(t *testing.T) { // nolint:gocyclo wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantToken: &testToken, }, + { + name: "callback returns success with request_mode=form_post", + clientID: "test-client-id", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.generateState = func() (state.State, error) { return "test-state", nil } + h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } + h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } + + cache := &mockSessionCache{t: t, getReturnsToken: nil} + cacheKey := SessionCacheKey{ + Issuer: formPostSuccessServer.URL, + ClientID: "test-client-id", + Scopes: []string{"test-scope"}, + RedirectURI: "http://localhost:0/callback", + } + t.Cleanup(func() { + require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys) + require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys) + require.Equal(t, []*oidctypes.Token{&testToken}, cache.sawPutTokens) + }) + require.NoError(t, WithSessionCache(cache)(h)) + require.NoError(t, WithClient(&http.Client{Timeout: 10 * time.Second})(h)) + + h.openURL = func(actualURL string) error { + parsedActualURL, err := url.Parse(actualURL) + require.NoError(t, err) + actualParams := parsedActualURL.Query() + + require.Contains(t, actualParams.Get("redirect_uri"), "http://127.0.0.1:") + actualParams.Del("redirect_uri") + + require.Equal(t, url.Values{ + // This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example: + // $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1 + // VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g + "code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"}, + "code_challenge_method": []string{"S256"}, + "response_type": []string{"code"}, + "response_mode": []string{"form_post"}, + "scope": []string{"test-scope"}, + "nonce": []string{"test-nonce"}, + "state": []string{"test-state"}, + "access_type": []string{"offline"}, + "client_id": []string{"test-client-id"}, + }, actualParams) + + parsedActualURL.RawQuery = "" + require.Equal(t, formPostSuccessServer.URL+"/authorize", parsedActualURL.String()) + + go func() { + h.callbacks <- callbackResult{token: &testToken} + }() + return nil + } + return nil + } + }, + issuer: formPostSuccessServer.URL, + wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + formPostSuccessServer.URL + "\""}, + wantToken: &testToken, + }, { name: "upstream name and type are included in authorize request if upstream name is provided", clientID: "test-client-id",