Tiny bit more code for Supervisor's callback_handler.go

Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
Andrew Keesler 2020-11-13 15:59:51 -08:00 committed by Ryan Richard
parent 81b9a48437
commit 3ef1171667
7 changed files with 268 additions and 75 deletions

View File

@ -17,6 +17,7 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidcclient/nonce" "go.pinniped.dev/internal/oidcclient/nonce"
@ -42,10 +43,6 @@ const (
csrfCookieEncodingName = "csrf" csrfCookieEncodingName = "csrf"
) )
type IDPListGetter interface {
GetIDPList() []provider.UpstreamOIDCIdentityProvider
}
// This is the encoding side of the securecookie.Codec interface. // This is the encoding side of the securecookie.Codec interface.
type Encoder interface { type Encoder interface {
Encode(name string, value interface{}) (string, error) Encode(name string, value interface{}) (string, error)
@ -53,7 +50,7 @@ type Encoder interface {
func NewHandler( func NewHandler(
issuer string, issuer string,
idpListGetter IDPListGetter, idpListGetter oidc.IDPListGetter,
oauthHelper fosite.OAuth2Provider, oauthHelper fosite.OAuth2Provider,
generateCSRF func() (csrftoken.CSRFToken, error), generateCSRF func() (csrftoken.CSRFToken, error),
generatePKCE func() (pkce.Code, error), generatePKCE func() (pkce.Code, error),
@ -178,7 +175,7 @@ func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) {
} }
} }
func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) { func chooseUpstreamIDP(idpListGetter oidc.IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) {
allUpstreamIDPs := idpListGetter.GetIDPList() allUpstreamIDPs := idpListGetter.GetIDPList()
if len(allUpstreamIDPs) == 0 { if len(allUpstreamIDPs) == 0 {
return nil, httperr.New( return nil, httperr.New(

View File

@ -23,6 +23,7 @@ import (
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidcclient/nonce" "go.pinniped.dev/internal/oidcclient/nonce"
"go.pinniped.dev/internal/oidcclient/pkce" "go.pinniped.dev/internal/oidcclient/pkce"
"go.pinniped.dev/internal/testutil"
) )
func TestAuthorizationEndpoint(t *testing.T) { func TestAuthorizationEndpoint(t *testing.T) {
@ -210,7 +211,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
csrf = csrfValueOverride csrf = csrfValueOverride
} }
encoded, err := happyStateEncoder.Encode("s", encoded, err := happyStateEncoder.Encode("s",
expectedUpstreamStateParamFormat{ testutil.ExpectedUpstreamStateParamFormat{
P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)), P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)),
N: happyNonce, N: happyNonce,
C: csrf, C: csrf,
@ -270,7 +271,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "happy path using GET without a CSRF cookie", name: "happy path using GET without a CSRF cookie",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -288,7 +289,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "happy path using GET with a CSRF cookie", name: "happy path using GET with a CSRF cookie",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -306,7 +307,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "happy path using POST", name: "happy path using POST",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -326,7 +327,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "happy path when downstream redirect uri matches what is configured for client except for the port number", name: "happy path when downstream redirect uri matches what is configured for client except for the port number",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -348,7 +349,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "downstream redirect uri does not match what is configured for client", name: "downstream redirect uri does not match what is configured for client",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -365,7 +366,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "downstream client does not exist", name: "downstream client does not exist",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -380,7 +381,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "response type is unsupported", name: "response type is unsupported",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -396,7 +397,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "downstream scopes do not match what is configured for client", name: "downstream scopes do not match what is configured for client",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -412,7 +413,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "missing response type in request", name: "missing response type in request",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -428,7 +429,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "missing client id in request", name: "missing client id in request",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -443,7 +444,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -459,7 +460,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3 name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -475,7 +476,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3 name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -491,7 +492,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1 name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -509,7 +510,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
// through that part of the fosite library. // through that part of the fosite library.
name: "prompt param is not allowed to have none and another legal value at the same time", name: "prompt param is not allowed to have none and another legal value at the same time",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -525,7 +526,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "OIDC validations are skipped when the openid scope was not requested", name: "OIDC validations are skipped when the openid scope was not requested",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -546,7 +547,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "state does not have enough entropy", name: "state does not have enough entropy",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -562,7 +563,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "error while encoding upstream state param", name: "error while encoding upstream state param",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -577,7 +578,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "error while encoding CSRF cookie value for new cookie", name: "error while encoding CSRF cookie value for new cookie",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -592,7 +593,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "error while generating CSRF token", name: "error while generating CSRF token",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") }, generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") },
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -607,7 +608,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "error while generating nonce", name: "error while generating nonce",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") }, generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") },
@ -622,7 +623,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "error while generating PKCE", name: "error while generating PKCE",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") }, generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -637,7 +638,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "error while decoding CSRF cookie", name: "error while decoding CSRF cookie",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator, generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator, generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator, generateNonce: happyNonceGenerator,
@ -653,7 +654,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "no upstream providers are configured", name: "no upstream providers are configured",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(), // empty idpListGetter: testutil.NewIDPListGetter(), // empty
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusUnprocessableEntity, wantStatus: http.StatusUnprocessableEntity,
@ -663,7 +664,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "too many upstream providers are configured", name: "too many upstream providers are configured",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed
method: http.MethodGet, method: http.MethodGet,
path: happyGetRequestPath, path: happyGetRequestPath,
wantStatus: http.StatusUnprocessableEntity, wantStatus: http.StatusUnprocessableEntity,
@ -673,7 +674,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "PUT is a bad method", name: "PUT is a bad method",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
method: http.MethodPut, method: http.MethodPut,
path: "/some/path", path: "/some/path",
wantStatus: http.StatusMethodNotAllowed, wantStatus: http.StatusMethodNotAllowed,
@ -683,7 +684,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "PATCH is a bad method", name: "PATCH is a bad method",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
method: http.MethodPatch, method: http.MethodPatch,
path: "/some/path", path: "/some/path",
wantStatus: http.StatusMethodNotAllowed, wantStatus: http.StatusMethodNotAllowed,
@ -693,7 +694,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{ {
name: "DELETE is a bad method", name: "DELETE is a bad method",
issuer: issuer, issuer: issuer,
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider), idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
method: http.MethodDelete, method: http.MethodDelete,
path: "/some/path", path: "/some/path",
wantStatus: http.StatusMethodNotAllowed, wantStatus: http.StatusMethodNotAllowed,
@ -805,18 +806,6 @@ func TestAuthorizationEndpoint(t *testing.T) {
}) })
} }
// Declare a separate type from the production code to ensure that the state param's contents was serialized
// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of
// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality
// assertions about the redirect URL in this test.
type expectedUpstreamStateParamFormat struct {
P string `json:"p"`
N string `json:"n"`
C string `json:"c"`
K string `json:"k"`
V string `json:"v"`
}
type errorReturningEncoder struct { type errorReturningEncoder struct {
securecookie.Codec securecookie.Codec
} }
@ -850,13 +839,13 @@ func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL
expectedQueryStateParam := expectedLocationURL.Query().Get("state") expectedQueryStateParam := expectedLocationURL.Query().Get("state")
require.NotEmpty(t, expectedQueryStateParam) require.NotEmpty(t, expectedQueryStateParam)
var expectedDecodedStateParam expectedUpstreamStateParamFormat var expectedDecodedStateParam testutil.ExpectedUpstreamStateParamFormat
err = stateParamDecoder.Decode("s", expectedQueryStateParam, &expectedDecodedStateParam) err = stateParamDecoder.Decode("s", expectedQueryStateParam, &expectedDecodedStateParam)
require.NoError(t, err) require.NoError(t, err)
actualQueryStateParam := actualLocationURL.Query().Get("state") actualQueryStateParam := actualLocationURL.Query().Get("state")
require.NotEmpty(t, actualQueryStateParam) require.NotEmpty(t, actualQueryStateParam)
var actualDecodedStateParam expectedUpstreamStateParamFormat var actualDecodedStateParam testutil.ExpectedUpstreamStateParamFormat
err = stateParamDecoder.Decode("s", actualQueryStateParam, &actualDecodedStateParam) err = stateParamDecoder.Decode("s", actualQueryStateParam, &actualDecodedStateParam)
require.NoError(t, err) require.NoError(t, err)
@ -884,9 +873,3 @@ func requireEqualURLs(t *testing.T, actualURL string, expectedURL string, ignore
} }
require.Equal(t, expectedLocationQuery, actualLocationQuery) require.Equal(t, expectedLocationQuery, actualLocationQuery)
} }
func newIDPListGetter(upstreamOIDCIdentityProviders ...provider.UpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider {
idpProvider := provider.NewDynamicUpstreamIDPProvider()
idpProvider.SetIDPList(upstreamOIDCIdentityProviders)
return idpProvider
}

View File

@ -6,16 +6,43 @@ package callback
import ( import (
"net/http" "net/http"
"path"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/provider"
) )
func NewHandler() http.Handler { func NewHandler(
idpListGetter oidc.IDPListGetter,
) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method)
} }
return nil if r.FormValue("code") == "" {
return httperr.New(http.StatusBadRequest, "code param not found")
}
if r.FormValue("state") == "" {
return httperr.New(http.StatusBadRequest, "state param not found")
}
if findUpstreamIDPConfig(r, idpListGetter) == nil {
return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found")
}
return httperr.New(http.StatusBadRequest, "state param not valid")
}) })
} }
func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProvider {
_, lastPathComponent := path.Split(r.URL.Path)
for _, p := range idpListGetter.GetIDPList() {
if p.Name == lastPathComponent {
return &p
}
}
return nil
}

View File

@ -4,18 +4,76 @@
package callback package callback
import ( import (
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url"
"testing" "testing"
"github.com/gorilla/securecookie"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/testutil"
)
const (
happyUpstreamIDPName = "upstream-idp-name"
) )
func TestCallbackEndpoint(t *testing.T) { func TestCallbackEndpoint(t *testing.T) {
upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth")
require.NoError(t, err)
otherUpstreamAuthURL, err := url.Parse("https://some-other-upstream-idp:8443/auth")
require.NoError(t, err)
upstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{
Name: happyUpstreamIDPName,
ClientID: "some-client-id",
AuthorizationURL: *upstreamAuthURL,
Scopes: []string{"scope1", "scope2"},
}
otherUpstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{
Name: "other-upstream-idp-name",
ClientID: "other-some-client-id",
AuthorizationURL: *otherUpstreamAuthURL,
Scopes: []string{"other-scope1", "other-scope2"},
}
var stateEncoderHashKey = []byte("fake-hash-secret")
var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES
var cookieEncoderHashKey = []byte("fake-hash-secret2")
var cookieEncoderBlockKey = []byte("0123456789ABCDE2") // block encryption requires 16/24/32 bytes for AES
require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey)
require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey)
var happyStateEncoder = securecookie.New(stateEncoderHashKey, stateEncoderBlockKey)
happyStateEncoder.SetSerializer(securecookie.JSONEncoder{})
var happyCookieEncoder = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey)
happyCookieEncoder.SetSerializer(securecookie.JSONEncoder{})
//happyCSRF := "test-csrf"
//happyPKCE := "test-pkce"
//happyNonce := "test-nonce"
//
//happyEncodedState, err := happyStateEncoder.Encode("s",
// testutil.ExpectedUpstreamStateParamFormat{
// P: "todo query goes here",
// N: happyNonce,
// C: happyCSRF,
// K: happyPKCE,
// V: "1",
// },
//)
//require.NoError(t, err)
tests := []struct { tests := []struct {
name string name string
method string method string
path string
idpListGetter provider.DynamicUpstreamIDPProvider
wantStatus int wantStatus int
wantBody string wantBody string
@ -27,24 +85,67 @@ func TestCallbackEndpoint(t *testing.T) {
{ {
name: "PUT method is invalid", name: "PUT method is invalid",
method: http.MethodPut, method: http.MethodPut,
path: newRequestPath().String(),
wantStatus: http.StatusMethodNotAllowed, wantStatus: http.StatusMethodNotAllowed,
wantBody: "Method Not Allowed: PUT (try GET)\n", wantBody: "Method Not Allowed: PUT (try GET)\n",
}, },
// TODO: POST/PATCH/DELETE is invalid {
// TODO: request has body? maybe we don't need to do anything... name: "POST method is invalid",
// TODO: code does not exist method: http.MethodPost,
// TODO: we got called twice with the same state and cookie...is this bad? might be ok if the client's first roundtrip failed path: newRequestPath().String(),
// TODO: we got called twice with the same state and cookie and the UpstreamOIDCProvider CRD has been deleted wantStatus: http.StatusMethodNotAllowed,
// TODO: state does not exist wantBody: "Method Not Allowed: POST (try GET)\n",
// TODO: invalid signature on state },
// TODO: state is expired (the expiration is encoded in the state itself) {
// TODO: state csrf value does not match csrf cookie name: "PATCH method is invalid",
// TODO: cookie does not exist method: http.MethodPatch,
// TODO: invalid signature on cookie path: newRequestPath().String(),
// TODO: state version does not match what we want wantStatus: http.StatusMethodNotAllowed,
wantBody: "Method Not Allowed: PATCH (try GET)\n",
},
{
name: "DELETE method is invalid",
method: http.MethodDelete,
path: newRequestPath().String(),
wantStatus: http.StatusMethodNotAllowed,
wantBody: "Method Not Allowed: DELETE (try GET)\n",
},
{
name: "code param was not included on request",
method: http.MethodGet,
path: newRequestPath().WithoutCode().String(),
wantStatus: http.StatusBadRequest,
wantBody: "Bad Request: code param not found\n",
},
{
name: "state param was not included on request",
method: http.MethodGet,
path: newRequestPath().WithoutState().String(),
wantStatus: http.StatusBadRequest,
wantBody: "Bad Request: state param not found\n",
},
{
name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
method: http.MethodGet,
path: newRequestPath().WithState("this-will-not-decode").String(),
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
wantStatus: http.StatusBadRequest,
wantBody: "Bad Request: state param not valid\n",
},
{
name: "the UpstreamOIDCProvider CRD has been deleted",
method: http.MethodGet,
path: newRequestPath().String(),
idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider),
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: upstream provider not found\n",
},
// TODO: csrf cookie does not exist on request
// TODO: csrf cookie value cannot be decoded (e.g. invalid signture or any other decoding problem)
// TODO: csrf value from inside state param does not match csrf cookie value
// TODO: state's internal version does not match what we want
// Upstream exchange // Upstream exchange
// TODO: we can't figure out what the upstream token endpoint is (do we get this UpstreamOIDCProvider name from the path?)
// TODO: network call to upstream token endpoint fails // TODO: network call to upstream token endpoint fails
// TODO: the upstream token endpoint returns an error // TODO: the upstream token endpoint returns an error
@ -61,14 +162,15 @@ func TestCallbackEndpoint(t *testing.T) {
// TODO: here (e.g., id jwt cannot be verified, nonce is wrong, we didn't get refresh token, we didn't get access token, we didn't get id token, access token expires too quickly) // TODO: here (e.g., id jwt cannot be verified, nonce is wrong, we didn't get refresh token, we didn't get access token, we didn't get id token, access token expires too quickly)
// Downstream redirect // Downstream redirect
// TODO: we grant the openid scope if it was requested, similar to what we did in auth_handler.go
// TODO: cannot generate auth code // TODO: cannot generate auth code
// TODO: cannot persist downstream state // TODO: cannot persist downstream state
} }
for _, test := range tests { for _, test := range tests {
test := test test := test
t.Run(test.name, func(t *testing.T) { t.Run(test.name, func(t *testing.T) {
subject := NewHandler() subject := NewHandler(test.idpListGetter)
req := httptest.NewRequest(test.method, "/path-is-not-yet-tested", nil /* body not yet tested */) req := httptest.NewRequest(test.method, test.path, nil)
rsp := httptest.NewRecorder() rsp := httptest.NewRecorder()
subject.ServeHTTP(rsp, req) subject.ServeHTTP(rsp, req)
@ -77,3 +179,55 @@ func TestCallbackEndpoint(t *testing.T) {
}) })
} }
} }
type requestPath struct {
upstreamIDPName, code, state *string
}
func newRequestPath() *requestPath {
n := happyUpstreamIDPName
c := "1234"
s := "4321"
return &requestPath{
upstreamIDPName: &n,
code: &c,
state: &s,
}
}
func (r *requestPath) WithUpstreamIDPName(name string) *requestPath {
r.upstreamIDPName = &name
return r
}
func (r *requestPath) WithCode(code string) *requestPath {
r.code = &code
return r
}
func (r *requestPath) WithoutCode() *requestPath {
r.code = nil
return r
}
func (r *requestPath) WithState(state string) *requestPath {
r.state = &state
return r
}
func (r *requestPath) WithoutState() *requestPath {
r.state = nil
return r
}
func (r *requestPath) String() string {
path := fmt.Sprintf("/downstream-provider-name/callback/%s?", *r.upstreamIDPName)
params := url.Values{}
if r.code != nil {
params.Add("code", *r.code)
}
if r.state != nil {
params.Add("state", *r.state)
}
return path + params.Encode()
}

View File

@ -7,6 +7,8 @@ package oidc
import ( import (
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/fosite/compose" "github.com/ory/fosite/compose"
"go.pinniped.dev/internal/oidc/provider"
) )
const ( const (
@ -49,3 +51,7 @@ func FositeOauth2Helper(oauthStore interface{}, hmacSecretOfLengthAtLeast32 []by
compose.OAuth2PKCEFactory, compose.OAuth2PKCEFactory,
) )
} }
type IDPListGetter interface {
GetIDPList() []provider.UpstreamOIDCIdentityProvider
}

View File

@ -30,14 +30,14 @@ type Manager struct {
providerHandlers map[string]http.Handler // map of all routes for all providers providerHandlers map[string]http.Handler // map of all routes for all providers
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data
idpListGetter auth.IDPListGetter // in-memory cache of upstream IDPs idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs
} }
// NewManager returns an empty Manager. // NewManager returns an empty Manager.
// nextHandler will be invoked for any requests that could not be handled by this manager's providers. // nextHandler will be invoked for any requests that could not be handled by this manager's providers.
// dynamicJWKSProvider will be used as an in-memory cache for per-issuer JWKS data. // dynamicJWKSProvider will be used as an in-memory cache for per-issuer JWKS data.
// idpListGetter will be used as an in-memory cache of currently configured upstream IDPs. // idpListGetter will be used as an in-memory cache of currently configured upstream IDPs.
func NewManager(nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter auth.IDPListGetter) *Manager { func NewManager(nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter oidc.IDPListGetter) *Manager {
return &Manager{ return &Manager{
providerHandlers: make(map[string]http.Handler), providerHandlers: make(map[string]http.Handler),
nextHandler: nextHandler, nextHandler: nextHandler,

26
internal/testutil/oidc.go Normal file
View File

@ -0,0 +1,26 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package testutil
import "go.pinniped.dev/internal/oidc/provider"
// Test helpers for the OIDC package.
func NewIDPListGetter(upstreamOIDCIdentityProviders ...provider.UpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider {
idpProvider := provider.NewDynamicUpstreamIDPProvider()
idpProvider.SetIDPList(upstreamOIDCIdentityProviders)
return idpProvider
}
// Declare a separate type from the production code to ensure that the state param's contents was serialized
// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of
// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality
// assertions about the redirect URL in this test.
type ExpectedUpstreamStateParamFormat struct {
P string `json:"p"`
N string `json:"n"`
C string `json:"c"`
K string `json:"k"`
V string `json:"v"`
}