In form_post mode, expect params via POST'ed form.

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2021-06-21 14:40:08 -05:00
parent 40c931bdc5
commit 7217cf4892
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
2 changed files with 71 additions and 5 deletions

View File

@ -689,13 +689,29 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
}
}()
var params url.Values
if h.useFormPost {
// Return HTTP 405 for anything that's not a POST.
if r.Method != http.MethodPost {
return httperr.Newf(http.StatusMethodNotAllowed, "wanted POST")
}
// Parse and pull the response parameters from a application/x-www-form-urlencoded request body.
if err := r.ParseForm(); err != nil {
return httperr.Wrap(http.StatusBadRequest, "invalid form", err)
}
params = r.Form
} else {
// Return HTTP 405 for anything that's not a GET.
if r.Method != http.MethodGet {
return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET")
}
// Pull response parameters from the URL query string.
params = r.URL.Query()
}
// Validate OAuth2 state and fail if it's incorrect (to block CSRF).
params := r.URL.Query()
if err := h.state.Validate(params.Get("state")); err != nil {
return httperr.New(http.StatusForbidden, "missing or invalid state parameter")
}

View File

@ -4,6 +4,7 @@
package oidcclient
import (
"bytes"
"context"
"encoding/json"
"errors"
@ -1427,10 +1428,18 @@ func TestLogin(t *testing.T) { // nolint:gocyclo
func TestHandleAuthCodeCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback"
withFormPostMode := func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
return nil
}
}
tests := []struct {
name string
method string
query string
body []byte
contentType string
opt func(t *testing.T) Option
wantErr string
wantHTTPStatus int
@ -1442,6 +1451,24 @@ func TestHandleAuthCodeCallback(t *testing.T) {
wantErr: "wanted GET",
wantHTTPStatus: http.StatusMethodNotAllowed,
},
{
name: "wrong method for form_post",
method: "GET",
query: "",
opt: withFormPostMode,
wantErr: "wanted POST",
wantHTTPStatus: http.StatusMethodNotAllowed,
},
{
name: "invalid form for form_post",
method: "POST",
query: "",
contentType: "application/x-www-form-urlencoded",
body: []byte(`%`),
opt: withFormPostMode,
wantErr: `invalid form: invalid URL escape "%"`,
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid state",
query: "state=invalid",
@ -1496,6 +1523,26 @@ func TestHandleAuthCodeCallback(t *testing.T) {
}
},
},
{
name: "valid form_post",
method: http.MethodPost,
contentType: "application/x-www-form-urlencoded",
body: []byte(`state=test-state&code=valid`),
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.useFormPost = true
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil)
return mock
}
return nil
}
},
},
}
for _, tt := range tests {
tt := tt
@ -1514,12 +1561,15 @@ func TestHandleAuthCodeCallback(t *testing.T) {
defer cancel()
resp := httptest.NewRecorder()
req, err := http.NewRequestWithContext(ctx, "GET", "/test-callback", nil)
req, err := http.NewRequestWithContext(ctx, "GET", "/test-callback", bytes.NewBuffer(tt.body))
require.NoError(t, err)
req.URL.RawQuery = tt.query
if tt.method != "" {
req.Method = tt.method
}
if tt.contentType != "" {
req.Header.Set("Content-Type", tt.contentType)
}
err = h.handleAuthCodeCallback(resp, req)
if tt.wantErr != "" {