diff --git a/internal/oidc/oidc.go b/internal/oidc/oidc.go index 4da9cfcd..8b73e662 100644 --- a/internal/oidc/oidc.go +++ b/internal/oidc/oidc.go @@ -18,6 +18,7 @@ const ( WellKnownEndpointPath = "/.well-known/openid-configuration" AuthorizationEndpointPath = "/oauth2/authorize" TokenEndpointPath = "/oauth2/token" //nolint:gosec // ignore lint warning that this is a credential + CallbackEndpointPath = "/callback" JWKSEndpointPath = "/jwks.json" ) diff --git a/internal/oidc/provider/manager/manager.go b/internal/oidc/provider/manager/manager.go index f009693d..deb1cfc4 100644 --- a/internal/oidc/provider/manager/manager.go +++ b/internal/oidc/provider/manager/manager.go @@ -12,6 +12,7 @@ import ( "go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/auth" + "go.pinniped.dev/internal/oidc/callback" "go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/discovery" "go.pinniped.dev/internal/oidc/jwks" @@ -84,6 +85,9 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) { authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder, encoder) + callbackURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.CallbackEndpointPath + m.providerHandlers[callbackURL] = callback.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, encoder, encoder) + plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) } } diff --git a/internal/oidc/provider/manager/manager_test.go b/internal/oidc/provider/manager/manager_test.go index 86137abd..fdea39d5 100644 --- a/internal/oidc/provider/manager/manager_test.go +++ b/internal/oidc/provider/manager/manager_test.go @@ -42,6 +42,7 @@ func TestManager(t *testing.T) { issuer2DifferentCaseHostname = "https://exAmPlE.Com/some/path/more/deeply/nested/path" issuer2KeyID = "issuer2-key" upstreamIDPAuthorizationURL = "https://test-upstream.com/auth" + downstreamRedirectURL = "http://127.0.0.1:12345/callback" ) newGetRequest := func(url string) *http.Request { @@ -82,6 +83,18 @@ func TestManager(t *testing.T) { ) } + requireCallbackRequestToBeHandled := func(requestIssuer, requestURLSuffix string) { + recorder := httptest.NewRecorder() + + subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.CallbackEndpointPath+requestURLSuffix)) + + r.False(fallbackHandlerWasCalled) + + // Minimal check to ensure that the right endpoint was called - when we don't send a CSRF + // cookie to the callback endpoint, the callback endpoint responds with a 403. + r.Equal(http.StatusForbidden, recorder.Code) + } + requireJWKSRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedJWKKeyID string) { recorder := httptest.NewRecorder() @@ -162,7 +175,6 @@ func TestManager(t *testing.T) { requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "", issuer2KeyID) requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", issuer2KeyID) - authRedirectURI := "http://127.0.0.1/callback" authRequestParams := "?" + url.Values{ "response_type": []string{"code"}, "scope": []string{"openid profile email"}, @@ -171,7 +183,7 @@ func TestManager(t *testing.T) { "nonce": []string{"some-nonce-value"}, "code_challenge": []string{"some-challenge"}, "code_challenge_method": []string{"S256"}, - "redirect_uri": []string{authRedirectURI}, + "redirect_uri": []string{downstreamRedirectURL}, }.Encode() requireAuthorizationRequestToBeHandled(issuer1, authRequestParams, upstreamIDPAuthorizationURL) @@ -180,6 +192,18 @@ func TestManager(t *testing.T) { // Hostnames are case-insensitive, so test that we can handle that. requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) + + callbackRequestParams := "?" + url.Values{ + "code": []string{"some-code"}, + "state": []string{"some-state-value"}, + }.Encode() + + requireCallbackRequestToBeHandled(issuer1, callbackRequestParams) + requireCallbackRequestToBeHandled(issuer2, callbackRequestParams) + + // // Hostnames are case-insensitive, so test that we can handle that. + requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams) + requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams) } when("given some valid providers via SetProviders()", func() {