diff --git a/internal/oidc/callback/callback_handler.go b/internal/oidc/callback/callback_handler.go index a1cc49b4..7d1f555c 100644 --- a/internal/oidc/callback/callback_handler.go +++ b/internal/oidc/callback/callback_handler.go @@ -56,7 +56,7 @@ func NewHandler(idpListGetter oidc.IDPListGetter, oauthHelper fosite.OAuth2Provi "TODO", // TODO use the nonce value from the decoded state param here ) if err != nil { - panic(err) // TODO + return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens") } var username string diff --git a/internal/oidc/callback/callback_handler_test.go b/internal/oidc/callback/callback_handler_test.go index eb6de575..afa88c4b 100644 --- a/internal/oidc/callback/callback_handler_test.go +++ b/internal/oidc/callback/callback_handler_test.go @@ -5,6 +5,7 @@ package callback import ( "context" + "errors" "fmt" "net/http" "net/http/httptest" @@ -69,6 +70,17 @@ func TestCallbackEndpoint(t *testing.T) { Scopes: []string{"other-scope1", "other-scope2"}, } + failedExchangeUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{ + Name: happyUpstreamIDPName, + ClientID: upstreamOIDCIdentityProvider.ClientID, + UsernameClaim: upstreamOIDCIdentityProvider.UsernameClaim, + GroupsClaim: upstreamOIDCIdentityProvider.GroupsClaim, + Scopes: upstreamOIDCIdentityProvider.Scopes, + ExchangeAuthcodeAndValidateTokensFunc: func(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidcclient.Token, map[string]interface{}, error) { + return oidcclient.Token{}, nil, errors.New("some exchange error") + }, + } + var stateEncoderHashKey = []byte("fake-hash-secret") var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES var cookieEncoderHashKey = []byte("fake-hash-secret2") @@ -277,7 +289,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "the CSRF cookie does not exist on request", - idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), wantStatus: http.StatusForbidden, @@ -285,7 +297,7 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason", - idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), method: http.MethodGet, path: newRequestPath().WithState(happyState).String(), csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped", @@ -294,13 +306,24 @@ func TestCallbackEndpoint(t *testing.T) { }, { name: "cookie csrf value does not match state csrf value", - idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider), + idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider), method: http.MethodGet, path: newRequestPath().WithState(wrongCSRFValueState).String(), csrfCookie: happyCSRFCookie, wantStatus: http.StatusForbidden, wantBody: "Forbidden: CSRF value does not match\n", }, + + // Upstream exchange + { + name: "upstream auth code exchange fails", + idpListGetter: testutil.NewIDPListGetter(failedExchangeUpstreamOIDCIdentityProvider), + method: http.MethodGet, + path: newRequestPath().WithState(happyState).String(), + csrfCookie: happyCSRFCookie, + wantStatus: http.StatusBadGateway, + wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n", + }, } for _, test := range tests { test := test