diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go index 1ac7cb87..df90bde5 100644 --- a/internal/oidc/auth/auth_handler_test.go +++ b/internal/oidc/auth/auth_handler_test.go @@ -271,7 +271,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using GET without a CSRF cookie", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -289,7 +289,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using GET with a CSRF cookie", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -307,7 +307,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path using POST", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -327,7 +327,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "happy path when downstream redirect uri matches what is configured for client except for the port number", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -349,7 +349,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream redirect uri does not match what is configured for client", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -366,7 +366,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream client does not exist", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -381,7 +381,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "response type is unsupported", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -397,7 +397,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "downstream scopes do not match what is configured for client", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -413,7 +413,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing response type in request", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -429,7 +429,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing client id in request", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -444,7 +444,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -460,7 +460,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3 issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -476,7 +476,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3 issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -492,7 +492,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -510,7 +510,7 @@ func TestAuthorizationEndpoint(t *testing.T) { // through that part of the fosite library. name: "prompt param is not allowed to have none and another legal value at the same time", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -526,7 +526,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "OIDC validations are skipped when the openid scope was not requested", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -547,7 +547,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "state does not have enough entropy", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -563,7 +563,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while encoding upstream state param", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -578,7 +578,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while encoding CSRF cookie value for new cookie", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -593,7 +593,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating CSRF token", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -608,7 +608,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating nonce", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }, @@ -623,7 +623,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while generating PKCE", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generateNonce: happyNonceGenerator, @@ -638,7 +638,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "error while decoding CSRF cookie", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), generateCSRF: happyCSRFGenerator, generatePKCE: happyPKCEGenerator, generateNonce: happyNonceGenerator, @@ -664,7 +664,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "too many upstream providers are configured", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider), // more than one not allowed method: http.MethodGet, path: happyGetRequestPath, wantStatus: http.StatusUnprocessableEntity, @@ -674,7 +674,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "PUT is a bad method", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodPut, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -684,7 +684,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "PATCH is a bad method", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodPatch, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, @@ -694,7 +694,7 @@ func TestAuthorizationEndpoint(t *testing.T) { { name: "DELETE is a bad method", issuer: issuer, - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider), method: http.MethodDelete, path: "/some/path", wantStatus: http.StatusMethodNotAllowed, diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index 4208a746..985ade31 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -47,13 +47,14 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi } // Grant the openid scope only if it was requested. + // TODO: shouldn't we be potentially granting more scopes than just openid... grantOpenIDScopeIfRequested(authorizeRequester) _, idTokenClaims, err := upstreamIDPConfig.ExchangeAuthcodeAndValidateTokens( r.Context(), - "TODO", // TODO use the upstream authcode (code param) here - "TODO", // TODO use the pkce value from the decoded state param here - "TODO", // TODO use the nonce value from the decoded state param here + r.URL.Query().Get("code"), // TODO: do we need to validate this? + state.PKCECode, + state.Nonce, ) if err != nil { return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens") diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index 4b47e96b..b6585a3c 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -21,7 +21,6 @@ import ( "github.com/stretchr/testify/require" "go.pinniped.dev/internal/oidc" - "go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidcclient" "go.pinniped.dev/internal/oidcclient/nonce" "go.pinniped.dev/internal/oidcclient/pkce" @@ -35,6 +34,8 @@ const ( func TestCallbackEndpoint(t *testing.T) { const ( downstreamRedirectURI = "http://127.0.0.1/callback" + + happyUpstreamAuthcode = "upstream-auth-code" ) upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ @@ -170,13 +171,19 @@ func TestCallbackEndpoint(t *testing.T) { require.NoError(t, err) happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue + happyExchangeAndValidateTokensArgs := &testutil.ExchangeAuthcodeAndValidateTokenArgs{ + Authcode: happyUpstreamAuthcode, + PKCECodeVerifier: pkce.Code(happyPKCE), + ExpectedIDTokenNonce: nonce.Nonce(happyNonce), + } + tests := []struct { name string - idpListGetter provider.DynamicUpstreamIDPProvider - method string - path string - csrfCookie string + idp testutil.TestUpstreamOIDCIdentityProvider + method string + path string + csrfCookie string wantStatus int wantBody string @@ -184,19 +191,22 @@ func TestCallbackEndpoint(t *testing.T) { // TODO: I am unused... wantAuthcodeStored bool wantGrantedOpenidScope bool + + wantExchangeAndValidateTokensCall *testutil.ExchangeAuthcodeAndValidateTokenArgs }{ { - name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusFound, + name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it - wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState, - wantAuthcodeStored: true, - wantGrantedOpenidScope: true, - wantBody: "", + wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=openid&state=` + happyDownstreamState, + wantAuthcodeStored: true, + wantGrantedOpenidScope: true, + wantBody: "", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, // TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes) @@ -246,95 +256,97 @@ func TestCallbackEndpoint(t *testing.T) { wantBody: "Bad Request: state param not found\n", }, { - name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState("this-will-not-decode").String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: error reading state\n", + name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState("this-will-not-decode").String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: error reading state\n", }, { - name: "state's internal version does not match what we want", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(wrongVersionState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusUnprocessableEntity, - wantBody: "Unprocessable Entity: state format version is invalid\n", + name: "state's internal version does not match what we want", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(wrongVersionState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: state format version is invalid\n", }, { - name: "state's downstream auth params element is invalid", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: error reading state downstream auth params\n", + name: "state's downstream auth params element is invalid", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: error reading state downstream auth params\n", }, { - name: "state's downstream auth params are missing required value (e.g., client_id)", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(missingClientIDState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusBadRequest, - wantBody: "Bad Request: error using state downstream auth params\n", + name: "state's downstream auth params are missing required value (e.g., client_id)", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(missingClientIDState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadRequest, + wantBody: "Bad Request: error using state downstream auth params\n", }, { - name: "state's downstream auth params does not contain openid scope", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(noOpenidScopeState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusFound, - wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState, + name: "state's downstream auth params does not contain openid scope", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(noOpenidScopeState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusFound, + wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState, + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, { - name: "the UpstreamOIDCProvider CRD has been deleted", - idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusUnprocessableEntity, - wantBody: "Unprocessable Entity: upstream provider not found\n", + name: "the UpstreamOIDCProvider CRD has been deleted", + idp: otherUpstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusUnprocessableEntity, + wantBody: "Unprocessable Entity: upstream provider not found\n", }, { - name: "the CSRF cookie does not exist on request", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - wantStatus: http.StatusForbidden, - wantBody: "Forbidden: CSRF cookie is missing\n", + name: "the CSRF cookie does not exist on request", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + wantStatus: http.StatusForbidden, + wantBody: "Forbidden: CSRF cookie is missing\n", }, { - name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", - wantStatus: http.StatusForbidden, - wantBody: "Forbidden: error reading CSRF cookie\n", + name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", + wantStatus: http.StatusForbidden, + wantBody: "Forbidden: error reading CSRF cookie\n", }, { - name: "cookie csrf value does not match state csrf value", - idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(wrongCSRFValueState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusForbidden, - wantBody: "Forbidden: CSRF value does not match\n", + name: "cookie csrf value does not match state csrf value", + idp: upstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(wrongCSRFValueState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusForbidden, + wantBody: "Forbidden: CSRF value does not match\n", }, // Upstream exchange { - name: "upstream auth code exchange fails", - idpListGetter: testutil.NewIDPListGetter(failedExchangeUpstreamOIDCIdentityProvider), - method: http.MethodGet, - path: newRequestPath().WithState(happyState).String(), - csrfCookie: happyCSRFCookie, - wantStatus: http.StatusBadGateway, - wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n", + name: "upstream auth code exchange fails", + idp: failedExchangeUpstreamOIDCIdentityProvider, + method: http.MethodGet, + path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadGateway, + wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n", + wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs, }, } for _, test := range tests { @@ -352,7 +364,8 @@ func TestCallbackEndpoint(t *testing.T) { require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret) - subject := NewHandler(test.idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) + idpListGetter := testutil.NewIDPListGetter(&test.idp) + subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) req := httptest.NewRequest(test.method, test.path, nil) if test.csrfCookie != "" { req.Header.Set("Cookie", test.csrfCookie) @@ -408,6 +421,14 @@ func TestCallbackEndpoint(t *testing.T) { } else { require.Empty(t, rsp.Header().Values("Location")) } + + if test.wantExchangeAndValidateTokensCall != nil { + require.Equal(t, 1, test.idp.ExchangeAuthcodeAndValidateTokensCallCount()) + test.wantExchangeAndValidateTokensCall.Ctx = req.Context() + require.Equal(t, test.wantExchangeAndValidateTokensCall, test.idp.ExchangeAuthcodeAndValidateTokensArgs(0)) + } else { + require.Equal(t, 0, test.idp.ExchangeAuthcodeAndValidateTokensCallCount()) + } }) } } diff --git a/internal/oidc/provider/manager/manager_test.go b/internal/oidc/provider/manager/manager_test.go index e9e175d8..fce748bb 100644 --- a/internal/oidc/provider/manager/manager_test.go +++ b/internal/oidc/provider/manager/manager_test.go @@ -109,7 +109,7 @@ func TestManager(t *testing.T) { parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL) r.NoError(err) - idpListGetter := testutil.NewIDPListGetter(testutil.TestUpstreamOIDCIdentityProvider{ + idpListGetter := testutil.NewIDPListGetter(&testutil.TestUpstreamOIDCIdentityProvider{ Name: "test-idp", ClientID: "test-client-id", AuthorizationURL: *parsedUpstreamIDPAuthorizationURL, diff --git a/internal/testutil/oidc.go b/internal/testutil/oidc.go index 14bdb92b..7cbfcf81 100644 --- a/internal/testutil/oidc.go +++ b/internal/testutil/oidc.go @@ -15,6 +15,15 @@ import ( // Test helpers for the OIDC package. +// ExchangeAuthcodeAndValidateTokenArgs is a POGO (plain old go object?) used to spy on calls to +// TestUpstreamOIDCIdentityProvider.ExchangeAuthcodeAndValidateTokensFunc(). +type ExchangeAuthcodeAndValidateTokenArgs struct { + Ctx context.Context + Authcode string + PKCECodeVerifier pkce.Code + ExpectedIDTokenNonce nonce.Nonce +} + type TestUpstreamOIDCIdentityProvider struct { Name string ClientID string @@ -28,6 +37,9 @@ type TestUpstreamOIDCIdentityProvider struct { pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, ) (oidcclient.Token, map[string]interface{}, error) + + exchangeAuthcodeAndValidateTokensCallCount int + exchangeAuthcodeAndValidateTokensArgs []*ExchangeAuthcodeAndValidateTokenArgs } func (u *TestUpstreamOIDCIdentityProvider) GetName() string { @@ -60,14 +72,35 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens( pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, ) (oidcclient.Token, map[string]interface{}, error) { + if u.exchangeAuthcodeAndValidateTokensArgs == nil { + u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) + } + u.exchangeAuthcodeAndValidateTokensCallCount++ + u.exchangeAuthcodeAndValidateTokensArgs = append(u.exchangeAuthcodeAndValidateTokensArgs, &ExchangeAuthcodeAndValidateTokenArgs{ + Ctx: ctx, + Authcode: authcode, + PKCECodeVerifier: pkceCodeVerifier, + ExpectedIDTokenNonce: expectedIDTokenNonce, + }) return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce) } -func NewIDPListGetter(upstreamOIDCIdentityProviders ...TestUpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { +func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensCallCount() int { + return u.exchangeAuthcodeAndValidateTokensCallCount +} + +func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokensArgs(call int) *ExchangeAuthcodeAndValidateTokenArgs { + if u.exchangeAuthcodeAndValidateTokensArgs == nil { + u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) + } + return u.exchangeAuthcodeAndValidateTokensArgs[call] +} + +func NewIDPListGetter(upstreamOIDCIdentityProviders ...*TestUpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider { idpProvider := provider.NewDynamicUpstreamIDPProvider() upstreams := make([]provider.UpstreamOIDCIdentityProviderI, len(upstreamOIDCIdentityProviders)) for i := range upstreamOIDCIdentityProviders { - upstreams[i] = provider.UpstreamOIDCIdentityProviderI(&upstreamOIDCIdentityProviders[i]) + upstreams[i] = provider.UpstreamOIDCIdentityProviderI(upstreamOIDCIdentityProviders[i]) } idpProvider.SetIDPList(upstreams) return idpProvider