diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 5bd04127..d204d361 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -5,7 +5,9 @@ package callback import ( + "fmt" "net/http" + "net/url" "path" "go.pinniped.dev/internal/httputil/httperr" @@ -20,7 +22,8 @@ func NewHandler( stateDecoder, cookieDecoder oidc.Decoder, ) http.Handler { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - if err := validateRequest(r, stateDecoder, cookieDecoder); err != nil { + state, err := validateRequest(r, stateDecoder, cookieDecoder) + if err != nil { return err } @@ -29,43 +32,56 @@ func NewHandler( return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found") } + downstreamAuthParams, err := url.ParseQuery(state.AuthParams) + if err != nil { + panic(err) + } + + downstreamCallbackURL := fmt.Sprintf( + "%s?code=%s&state=%s", + downstreamAuthParams.Get("redirect_uri"), + url.QueryEscape("some-code"), + url.QueryEscape(downstreamAuthParams.Get("state")), + ) + http.Redirect(w, r, downstreamCallbackURL, 302) + return nil }) } -func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) error { +func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) { if r.Method != http.MethodGet { - return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) + return nil, httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) } csrfValue, err := readCSRFCookie(r, cookieDecoder) if err != nil { plog.InfoErr("error reading CSRF cookie", err) - return err + return nil, err } if r.FormValue("code") == "" { plog.Info("code param not found") - return httperr.New(http.StatusBadRequest, "code param not found") + return nil, httperr.New(http.StatusBadRequest, "code param not found") } if r.FormValue("state") == "" { plog.Info("state param not found") - return httperr.New(http.StatusBadRequest, "state param not found") + return nil, httperr.New(http.StatusBadRequest, "state param not found") } state, err := readState(r, stateDecoder) if err != nil { plog.InfoErr("error reading state", err) - return err + return nil, err } if state.CSRFToken != csrfValue { plog.InfoErr("CSRF value does not match", err) - return httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err) + return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err) } - return nil + return state, nil } func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProvider { diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 46dff9e2..29147a08 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -5,9 +5,11 @@ package callback import ( "fmt" + "html" "net/http" "net/http/httptest" "net/url" + "regexp" "testing" "github.com/gorilla/securecookie" @@ -22,6 +24,10 @@ const ( ) func TestCallbackEndpoint(t *testing.T) { + const ( + downstreamRedirectURI = "http://127.0.0.1/callback" + ) + upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth") require.NoError(t, err) otherUpstreamAuthURL, err := url.Parse("https://some-other-upstream-idp:8443/auth") @@ -53,13 +59,25 @@ func TestCallbackEndpoint(t *testing.T) { var happyCookieCodec = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) happyCookieCodec.SetSerializer(securecookie.JSONEncoder{}) + happyDownstreamState := "some-downstream-state" + + happyOrignalRequestParams := url.Values{ + "response_type": []string{"code"}, + "scope": []string{"openid profile email"}, + "client_id": []string{"pinniped-cli"}, + "state": []string{happyDownstreamState}, + "nonce": []string{"some-nonce-value"}, + "code_challenge": []string{"some-challenge"}, + "code_challenge_method": []string{"S256"}, + "redirect_uri": []string{downstreamRedirectURI}, + }.Encode() happyCSRF := "test-csrf" happyPKCE := "test-pkce" happyNonce := "test-nonce" happyState, err := happyStateCodec.Encode("s", testutil.ExpectedUpstreamStateParamFormat{ - P: "todo query goes here", + P: happyOrignalRequestParams, N: happyNonce, C: happyCSRF, K: happyPKCE, @@ -70,7 +88,7 @@ func TestCallbackEndpoint(t *testing.T) { wrongCSRFValueState, err := happyStateCodec.Encode("s", testutil.ExpectedUpstreamStateParamFormat{ - P: "todo query goes here", + P: happyOrignalRequestParams, N: happyNonce, C: "wrong-csrf-value", K: happyPKCE, @@ -81,7 +99,7 @@ func TestCallbackEndpoint(t *testing.T) { wrongVersionState, err := happyStateCodec.Encode("s", testutil.ExpectedUpstreamStateParamFormat{ - P: "todo query goes here", + P: happyOrignalRequestParams, N: happyNonce, C: happyCSRF, K: happyPKCE, @@ -102,11 +120,22 @@ func TestCallbackEndpoint(t *testing.T) { path string csrfCookie string - wantStatus int - wantBody string + wantStatus int + wantBody string + wantRedirectLocationRegexp string }{ // Happy path // TODO: GET with good state and cookie and successful upstream token exchange and 302 to downstream client callback with its state and code + { + name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&state=` + happyDownstreamState, + }, + // TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes) // Pre-upstream-exchange verification { @@ -240,7 +269,31 @@ func TestCallbackEndpoint(t *testing.T) { subject.ServeHTTP(rsp, req) require.Equal(t, test.wantStatus, rsp.Code) - require.Equal(t, test.wantBody, rsp.Body.String()) + + require.False(t, test.wantBody != "" && test.wantRedirectLocationRegexp != "", "test cannot set both body and redirect assertions") + switch { + case test.wantBody != "": + require.Empty(t, rsp.Header().Values("Location")) + require.Equal(t, test.wantBody, rsp.Body.String()) + case test.wantRedirectLocationRegexp != "": + // Assert that Location header matches regular expression. + require.Len(t, rsp.Header().Values("Location"), 1) + actualLocation := rsp.Header().Get("Location") + regex := regexp.MustCompile(test.wantRedirectLocationRegexp) + submatches := regex.FindStringSubmatch(actualLocation) + require.Lenf(t, submatches, 2, "no regexp match in actualLocation: %q", actualLocation) + capturedAuthCode := submatches[1] + _ = capturedAuthCode + + // Assert capturedAuthCode storage stuff... + + // Assert that body contains anchor tag with redirect location. + anchorTagWithLocationHref := fmt.Sprintf("Found.\n\n", html.EscapeString(actualLocation)) + require.Equal(t, anchorTagWithLocationHref, rsp.Body.String()) + default: + require.Empty(t, rsp.Header().Values("Location")) + require.Empty(t, rsp.Body.String()) + } }) } }