When supported, use "response_mode=form_post" in client.

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2021-06-21 14:19:12 -05:00
parent 2823d4d1e3
commit 40c931bdc5
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
2 changed files with 151 additions and 26 deletions

View File

@ -87,6 +87,7 @@ type handlerState struct {
// Generated parameters of a login flow. // Generated parameters of a login flow.
provider *oidc.Provider provider *oidc.Provider
oauth2Config *oauth2.Config oauth2Config *oauth2.Config
useFormPost bool
state state.State state state.State
nonce nonce.Nonce nonce nonce.Nonce
pkce pkce.Code pkce pkce.Code
@ -486,8 +487,14 @@ func (h *handlerState) webBrowserBasedAuth(authorizeOptions *[]oauth2.AuthCodeOp
Path: h.callbackPath, Path: h.callbackPath,
}).String() }).String()
// If the server supports it, request response_mode=form_post.
authParams := *authorizeOptions
if h.useFormPost {
authParams = append(authParams, oauth2.SetAuthURLParam("response_mode", "form_post"))
}
// Now that we have a redirect URL with the listener port, we can build the authorize URL. // Now that we have a redirect URL with the listener port, we can build the authorize URL.
authorizeURL := h.oauth2Config.AuthCodeURL(h.state.String(), *authorizeOptions...) authorizeURL := h.oauth2Config.AuthCodeURL(h.state.String(), authParams...)
// Start a callback server in a background goroutine. // Start a callback server in a background goroutine.
shutdown := h.serve(listener) shutdown := h.serve(listener)
@ -567,9 +574,27 @@ func (h *handlerState) initOIDCDiscovery() error {
Endpoint: h.provider.Endpoint(), Endpoint: h.provider.Endpoint(),
Scopes: h.scopes, Scopes: h.scopes,
} }
// Use response_mode=form_post if the provider supports it.
var discoveryClaims struct {
ResponseModesSupported []string `json:"response_modes_supported"`
}
if err := h.provider.Claims(&discoveryClaims); err != nil {
return fmt.Errorf("could not decode response_modes_supported in OIDC discovery from %q: %w", h.issuer, err)
}
h.useFormPost = stringSliceContains(discoveryClaims.ResponseModesSupported, "form_post")
return nil return nil
} }
func stringSliceContains(slice []string, s string) bool {
for _, item := range slice {
if item == s {
return true
}
}
return false
}
func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidctypes.Token, error) { func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidctypes.Token, error) {
h.logger.V(debugLogLevel).Info("Pinniped: Performing RFC8693 token exchange", "requestedAudience", h.requestedAudience) h.logger.V(debugLogLevel).Info("Pinniped: Performing RFC8693 token exchange", "requestedAudience", h.requestedAudience)
// Perform OIDC discovery. This may have already been performed if there was not a cached base token. // Perform OIDC discovery. This may have already been performed if there was not a cached base token.

View File

@ -80,6 +80,22 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
})) }))
t.Cleanup(errorServer.Close) t.Cleanup(errorServer.Close)
// Start a test server that returns discovery data with a broken response_modes_supported value.
brokenResponseModeMux := http.NewServeMux()
brokenResponseModeServer := httptest.NewServer(brokenResponseModeMux)
brokenResponseModeMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", "application/json")
type providerJSON struct {
Issuer string `json:"issuer"`
ResponseModesSupported string `json:"response_modes_supported"` // Wrong type (should be []string).
}
_ = json.NewEncoder(w).Encode(&providerJSON{
Issuer: brokenResponseModeServer.URL,
ResponseModesSupported: "invalid",
})
})
t.Cleanup(brokenResponseModeServer.Close)
// Start a test server that returns discovery data with a broken token URL // Start a test server that returns discovery data with a broken token URL
brokenTokenURLMux := http.NewServeMux() brokenTokenURLMux := http.NewServeMux()
brokenTokenURLServer := httptest.NewServer(brokenTokenURLMux) brokenTokenURLServer := httptest.NewServer(brokenTokenURLMux)
@ -100,30 +116,29 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
}) })
t.Cleanup(brokenTokenURLServer.Close) t.Cleanup(brokenTokenURLServer.Close)
// Start a test server that returns a real discovery document and answers refresh requests. discoveryHandler := func(server *httptest.Server, responseModes []string) http.HandlerFunc {
providerMux := http.NewServeMux() return func(w http.ResponseWriter, r *http.Request) {
successServer := httptest.NewServer(providerMux) if r.Method != http.MethodGet {
t.Cleanup(successServer.Close) http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
providerMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { return
if r.Method != http.MethodGet { }
http.Error(w, "unexpected method", http.StatusMethodNotAllowed) w.Header().Set("content-type", "application/json")
return _ = json.NewEncoder(w).Encode(&struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
ResponseModesSupported []string `json:"response_modes_supported,omitempty"`
}{
Issuer: server.URL,
AuthURL: server.URL + "/authorize",
TokenURL: server.URL + "/token",
JWKSURL: server.URL + "/keys",
ResponseModesSupported: responseModes,
})
} }
w.Header().Set("content-type", "application/json") }
type providerJSON struct { tokenHandler := func(w http.ResponseWriter, r *http.Request) {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
}
_ = json.NewEncoder(w).Encode(&providerJSON{
Issuer: successServer.URL,
AuthURL: successServer.URL + "/authorize",
TokenURL: successServer.URL + "/token",
JWKSURL: successServer.URL + "/keys",
})
})
providerMux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
http.Error(w, "unexpected method", http.StatusMethodNotAllowed) http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
return return
@ -204,7 +219,21 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
w.Header().Set("content-type", "application/json") w.Header().Set("content-type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(&response)) require.NoError(t, json.NewEncoder(w).Encode(&response))
}) }
// Start a test server that returns a real discovery document and answers refresh requests.
providerMux := http.NewServeMux()
successServer := httptest.NewServer(providerMux)
t.Cleanup(successServer.Close)
providerMux.HandleFunc("/.well-known/openid-configuration", discoveryHandler(successServer, nil))
providerMux.HandleFunc("/token", tokenHandler)
// Start a test server that returns a real discovery document and answers refresh requests, _and_ supports form_mode=post.
formPostProviderMux := http.NewServeMux()
formPostSuccessServer := httptest.NewServer(formPostProviderMux)
t.Cleanup(formPostSuccessServer.Close)
formPostProviderMux.HandleFunc("/.well-known/openid-configuration", discoveryHandler(formPostSuccessServer, []string{"query", "form_post"}))
formPostProviderMux.HandleFunc("/token", tokenHandler)
defaultDiscoveryResponse := func(req *http.Request) (*http.Response, error) { // nolint:unparam defaultDiscoveryResponse := func(req *http.Request) (*http.Response, error) { // nolint:unparam
// Call the handler function from the test server to calculate the response. // Call the handler function from the test server to calculate the response.
@ -349,7 +378,7 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
wantToken: &testToken, wantToken: &testToken,
}, },
{ {
name: "discovery failure", name: "discovery failure due to 500 error",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return nil } return func(h *handlerState) error { return nil }
}, },
@ -357,6 +386,15 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + errorServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + errorServer.URL + "\""},
wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL), wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL),
}, },
{
name: "discovery failure due to invalid response_modes_supported",
opt: func(t *testing.T) Option {
return func(h *handlerState) error { return nil }
},
issuer: brokenResponseModeServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + brokenResponseModeServer.URL + "\""},
wantErr: fmt.Sprintf("could not decode response_modes_supported in OIDC discovery from %q: json: cannot unmarshal string into Go struct field .response_modes_supported of type []string", brokenResponseModeServer.URL),
},
{ {
name: "session cache hit with refreshable token", name: "session cache hit with refreshable token",
issuer: successServer.URL, issuer: successServer.URL,
@ -580,6 +618,68 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""}, wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + successServer.URL + "\""},
wantToken: &testToken, wantToken: &testToken,
}, },
{
name: "callback returns success with request_mode=form_post",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
Issuer: formPostSuccessServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Equal(t, []*oidctypes.Token{&testToken}, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithClient(&http.Client{Timeout: 10 * time.Second})(h))
h.openURL = func(actualURL string) error {
parsedActualURL, err := url.Parse(actualURL)
require.NoError(t, err)
actualParams := parsedActualURL.Query()
require.Contains(t, actualParams.Get("redirect_uri"), "http://127.0.0.1:")
actualParams.Del("redirect_uri")
require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"},
"response_type": []string{"code"},
"response_mode": []string{"form_post"},
"scope": []string{"test-scope"},
"nonce": []string{"test-nonce"},
"state": []string{"test-state"},
"access_type": []string{"offline"},
"client_id": []string{"test-client-id"},
}, actualParams)
parsedActualURL.RawQuery = ""
require.Equal(t, formPostSuccessServer.URL+"/authorize", parsedActualURL.String())
go func() {
h.callbacks <- callbackResult{token: &testToken}
}()
return nil
}
return nil
}
},
issuer: formPostSuccessServer.URL,
wantLogs: []string{"\"level\"=4 \"msg\"=\"Pinniped: Performing OIDC discovery\" \"issuer\"=\"" + formPostSuccessServer.URL + "\""},
wantToken: &testToken,
},
{ {
name: "upstream name and type are included in authorize request if upstream name is provided", name: "upstream name and type are included in authorize request if upstream name is provided",
clientID: "test-client-id", clientID: "test-client-id",