internal/oidc/provider/manager: route to callback endpoint

Signed-off-by: Andrew Keesler <akeesler@vmware.com>
This commit is contained in:
Andrew Keesler 2020-11-20 10:42:43 -05:00
parent 8f5d1709a1
commit 488d1b663a
No known key found for this signature in database
GPG Key ID: 27CE0444346F9413
3 changed files with 31 additions and 2 deletions

View File

@ -18,6 +18,7 @@ const (
WellKnownEndpointPath = "/.well-known/openid-configuration" WellKnownEndpointPath = "/.well-known/openid-configuration"
AuthorizationEndpointPath = "/oauth2/authorize" AuthorizationEndpointPath = "/oauth2/authorize"
TokenEndpointPath = "/oauth2/token" //nolint:gosec // ignore lint warning that this is a credential TokenEndpointPath = "/oauth2/token" //nolint:gosec // ignore lint warning that this is a credential
CallbackEndpointPath = "/callback"
JWKSEndpointPath = "/jwks.json" JWKSEndpointPath = "/jwks.json"
) )

View File

@ -12,6 +12,7 @@ import (
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/auth" "go.pinniped.dev/internal/oidc/auth"
"go.pinniped.dev/internal/oidc/callback"
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/discovery" "go.pinniped.dev/internal/oidc/discovery"
"go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/jwks"
@ -84,6 +85,9 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath
m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder, encoder) m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder, encoder)
callbackURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.CallbackEndpointPath
m.providerHandlers[callbackURL] = callback.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, encoder, encoder)
plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer())
} }
} }

View File

@ -42,6 +42,7 @@ func TestManager(t *testing.T) {
issuer2DifferentCaseHostname = "https://exAmPlE.Com/some/path/more/deeply/nested/path" issuer2DifferentCaseHostname = "https://exAmPlE.Com/some/path/more/deeply/nested/path"
issuer2KeyID = "issuer2-key" issuer2KeyID = "issuer2-key"
upstreamIDPAuthorizationURL = "https://test-upstream.com/auth" upstreamIDPAuthorizationURL = "https://test-upstream.com/auth"
downstreamRedirectURL = "http://127.0.0.1:12345/callback"
) )
newGetRequest := func(url string) *http.Request { newGetRequest := func(url string) *http.Request {
@ -82,6 +83,18 @@ func TestManager(t *testing.T) {
) )
} }
requireCallbackRequestToBeHandled := func(requestIssuer, requestURLSuffix string) {
recorder := httptest.NewRecorder()
subject.ServeHTTP(recorder, newGetRequest(requestIssuer+oidc.CallbackEndpointPath+requestURLSuffix))
r.False(fallbackHandlerWasCalled)
// Minimal check to ensure that the right endpoint was called - when we don't send a CSRF
// cookie to the callback endpoint, the callback endpoint responds with a 403.
r.Equal(http.StatusForbidden, recorder.Code)
}
requireJWKSRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedJWKKeyID string) { requireJWKSRequestToBeHandled := func(requestIssuer, requestURLSuffix, expectedJWKKeyID string) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@ -162,7 +175,6 @@ func TestManager(t *testing.T) {
requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "", issuer2KeyID) requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "", issuer2KeyID)
requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", issuer2KeyID) requireJWKSRequestToBeHandled(issuer2DifferentCaseHostname, "?some=query", issuer2KeyID)
authRedirectURI := "http://127.0.0.1/callback"
authRequestParams := "?" + url.Values{ authRequestParams := "?" + url.Values{
"response_type": []string{"code"}, "response_type": []string{"code"},
"scope": []string{"openid profile email"}, "scope": []string{"openid profile email"},
@ -171,7 +183,7 @@ func TestManager(t *testing.T) {
"nonce": []string{"some-nonce-value"}, "nonce": []string{"some-nonce-value"},
"code_challenge": []string{"some-challenge"}, "code_challenge": []string{"some-challenge"},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"redirect_uri": []string{authRedirectURI}, "redirect_uri": []string{downstreamRedirectURL},
}.Encode() }.Encode()
requireAuthorizationRequestToBeHandled(issuer1, authRequestParams, upstreamIDPAuthorizationURL) requireAuthorizationRequestToBeHandled(issuer1, authRequestParams, upstreamIDPAuthorizationURL)
@ -180,6 +192,18 @@ func TestManager(t *testing.T) {
// Hostnames are case-insensitive, so test that we can handle that. // Hostnames are case-insensitive, so test that we can handle that.
requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) requireAuthorizationRequestToBeHandled(issuer1DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL)
requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL) requireAuthorizationRequestToBeHandled(issuer2DifferentCaseHostname, authRequestParams, upstreamIDPAuthorizationURL)
callbackRequestParams := "?" + url.Values{
"code": []string{"some-code"},
"state": []string{"some-state-value"},
}.Encode()
requireCallbackRequestToBeHandled(issuer1, callbackRequestParams)
requireCallbackRequestToBeHandled(issuer2, callbackRequestParams)
// // Hostnames are case-insensitive, so test that we can handle that.
requireCallbackRequestToBeHandled(issuer1DifferentCaseHostname, callbackRequestParams)
requireCallbackRequestToBeHandled(issuer2DifferentCaseHostname, callbackRequestParams)
} }
when("given some valid providers via SetProviders()", func() { when("given some valid providers via SetProviders()", func() {