From 7217cf489263fd4d0e5221e26a6c0dc77f0cdded Mon Sep 17 00:00:00 2001 From: Matt Moyer Date: Mon, 21 Jun 2021 14:40:08 -0500 Subject: [PATCH] In form_post mode, expect params via POST'ed form. Signed-off-by: Matt Moyer --- pkg/oidcclient/login.go | 24 ++++++++++++++--- pkg/oidcclient/login_test.go | 52 +++++++++++++++++++++++++++++++++++- 2 files changed, 71 insertions(+), 5 deletions(-) diff --git a/pkg/oidcclient/login.go b/pkg/oidcclient/login.go index f354b94a..cd03e4ee 100644 --- a/pkg/oidcclient/login.go +++ b/pkg/oidcclient/login.go @@ -689,13 +689,29 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req } }() - // Return HTTP 405 for anything that's not a GET. - if r.Method != http.MethodGet { - return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET") + var params url.Values + if h.useFormPost { + // Return HTTP 405 for anything that's not a POST. + if r.Method != http.MethodPost { + return httperr.Newf(http.StatusMethodNotAllowed, "wanted POST") + } + + // Parse and pull the response parameters from a application/x-www-form-urlencoded request body. + if err := r.ParseForm(); err != nil { + return httperr.Wrap(http.StatusBadRequest, "invalid form", err) + } + params = r.Form + } else { + // Return HTTP 405 for anything that's not a GET. + if r.Method != http.MethodGet { + return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET") + } + + // Pull response parameters from the URL query string. + params = r.URL.Query() } // Validate OAuth2 state and fail if it's incorrect (to block CSRF). - params := r.URL.Query() if err := h.state.Validate(params.Get("state")); err != nil { return httperr.New(http.StatusForbidden, "missing or invalid state parameter") } diff --git a/pkg/oidcclient/login_test.go b/pkg/oidcclient/login_test.go index 31d84def..0e3ce673 100644 --- a/pkg/oidcclient/login_test.go +++ b/pkg/oidcclient/login_test.go @@ -4,6 +4,7 @@ package oidcclient import ( + "bytes" "context" "encoding/json" "errors" @@ -1427,10 +1428,18 @@ func TestLogin(t *testing.T) { // nolint:gocyclo func TestHandleAuthCodeCallback(t *testing.T) { const testRedirectURI = "http://127.0.0.1:12324/callback" + withFormPostMode := func(t *testing.T) Option { + return func(h *handlerState) error { + h.useFormPost = true + return nil + } + } tests := []struct { name string method string query string + body []byte + contentType string opt func(t *testing.T) Option wantErr string wantHTTPStatus int @@ -1442,6 +1451,24 @@ func TestHandleAuthCodeCallback(t *testing.T) { wantErr: "wanted GET", wantHTTPStatus: http.StatusMethodNotAllowed, }, + { + name: "wrong method for form_post", + method: "GET", + query: "", + opt: withFormPostMode, + wantErr: "wanted POST", + wantHTTPStatus: http.StatusMethodNotAllowed, + }, + { + name: "invalid form for form_post", + method: "POST", + query: "", + contentType: "application/x-www-form-urlencoded", + body: []byte(`%`), + opt: withFormPostMode, + wantErr: `invalid form: invalid URL escape "%"`, + wantHTTPStatus: http.StatusBadRequest, + }, { name: "invalid state", query: "state=invalid", @@ -1496,6 +1523,26 @@ func TestHandleAuthCodeCallback(t *testing.T) { } }, }, + { + name: "valid form_post", + method: http.MethodPost, + contentType: "application/x-www-form-urlencoded", + body: []byte(`state=test-state&code=valid`), + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + h.useFormPost = true + h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} + h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { + mock := mockUpstream(t) + mock.EXPECT(). + ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). + Return(&oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil) + return mock + } + return nil + } + }, + }, } for _, tt := range tests { tt := tt @@ -1514,12 +1561,15 @@ func TestHandleAuthCodeCallback(t *testing.T) { defer cancel() resp := httptest.NewRecorder() - req, err := http.NewRequestWithContext(ctx, "GET", "/test-callback", nil) + req, err := http.NewRequestWithContext(ctx, "GET", "/test-callback", bytes.NewBuffer(tt.body)) require.NoError(t, err) req.URL.RawQuery = tt.query if tt.method != "" { req.Method = tt.method } + if tt.contentType != "" { + req.Header.Set("Content-Type", tt.contentType) + } err = h.handleAuthCodeCallback(resp, req) if tt.wantErr != "" {