Tiny bit more code for Supervisor's callback_handler.go
Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
parent
81b9a48437
commit
3ef1171667
@ -17,6 +17,7 @@ import (
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/oidc"
|
||||
"go.pinniped.dev/internal/oidc/csrftoken"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/oidcclient/nonce"
|
||||
@ -42,10 +43,6 @@ const (
|
||||
csrfCookieEncodingName = "csrf"
|
||||
)
|
||||
|
||||
type IDPListGetter interface {
|
||||
GetIDPList() []provider.UpstreamOIDCIdentityProvider
|
||||
}
|
||||
|
||||
// This is the encoding side of the securecookie.Codec interface.
|
||||
type Encoder interface {
|
||||
Encode(name string, value interface{}) (string, error)
|
||||
@ -53,7 +50,7 @@ type Encoder interface {
|
||||
|
||||
func NewHandler(
|
||||
issuer string,
|
||||
idpListGetter IDPListGetter,
|
||||
idpListGetter oidc.IDPListGetter,
|
||||
oauthHelper fosite.OAuth2Provider,
|
||||
generateCSRF func() (csrftoken.CSRFToken, error),
|
||||
generatePKCE func() (pkce.Code, error),
|
||||
@ -178,7 +175,7 @@ func grantOpenIDScopeIfRequested(authorizeRequester fosite.AuthorizeRequester) {
|
||||
}
|
||||
}
|
||||
|
||||
func chooseUpstreamIDP(idpListGetter IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) {
|
||||
func chooseUpstreamIDP(idpListGetter oidc.IDPListGetter) (*provider.UpstreamOIDCIdentityProvider, error) {
|
||||
allUpstreamIDPs := idpListGetter.GetIDPList()
|
||||
if len(allUpstreamIDPs) == 0 {
|
||||
return nil, httperr.New(
|
||||
|
@ -23,6 +23,7 @@ import (
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/oidcclient/nonce"
|
||||
"go.pinniped.dev/internal/oidcclient/pkce"
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
)
|
||||
|
||||
func TestAuthorizationEndpoint(t *testing.T) {
|
||||
@ -210,7 +211,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
csrf = csrfValueOverride
|
||||
}
|
||||
encoded, err := happyStateEncoder.Encode("s",
|
||||
expectedUpstreamStateParamFormat{
|
||||
testutil.ExpectedUpstreamStateParamFormat{
|
||||
P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)),
|
||||
N: happyNonce,
|
||||
C: csrf,
|
||||
@ -270,7 +271,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "happy path using GET without a CSRF cookie",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -288,7 +289,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "happy path using GET with a CSRF cookie",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -306,7 +307,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "happy path using POST",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -326,7 +327,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "happy path when downstream redirect uri matches what is configured for client except for the port number",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -348,7 +349,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "downstream redirect uri does not match what is configured for client",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -365,7 +366,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "downstream client does not exist",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -380,7 +381,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "response type is unsupported",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -396,7 +397,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "downstream scopes do not match what is configured for client",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -412,7 +413,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "missing response type in request",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -428,7 +429,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "missing client id in request",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -443,7 +444,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -459,7 +460,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -475,7 +476,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -491,7 +492,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -509,7 +510,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
// through that part of the fosite library.
|
||||
name: "prompt param is not allowed to have none and another legal value at the same time",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -525,7 +526,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "OIDC validations are skipped when the openid scope was not requested",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -546,7 +547,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "state does not have enough entropy",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -562,7 +563,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "error while encoding upstream state param",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -577,7 +578,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "error while encoding CSRF cookie value for new cookie",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -592,7 +593,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "error while generating CSRF token",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") },
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -607,7 +608,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "error while generating nonce",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") },
|
||||
@ -622,7 +623,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "error while generating PKCE",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -637,7 +638,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "error while decoding CSRF cookie",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
generateCSRF: happyCSRFGenerator,
|
||||
generatePKCE: happyPKCEGenerator,
|
||||
generateNonce: happyNonceGenerator,
|
||||
@ -653,7 +654,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "no upstream providers are configured",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(), // empty
|
||||
idpListGetter: testutil.NewIDPListGetter(), // empty
|
||||
method: http.MethodGet,
|
||||
path: happyGetRequestPath,
|
||||
wantStatus: http.StatusUnprocessableEntity,
|
||||
@ -663,7 +664,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "too many upstream providers are configured",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider, upstreamOIDCIdentityProvider), // more than one not allowed
|
||||
method: http.MethodGet,
|
||||
path: happyGetRequestPath,
|
||||
wantStatus: http.StatusUnprocessableEntity,
|
||||
@ -673,7 +674,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "PUT is a bad method",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
method: http.MethodPut,
|
||||
path: "/some/path",
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
@ -683,7 +684,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "PATCH is a bad method",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
method: http.MethodPatch,
|
||||
path: "/some/path",
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
@ -693,7 +694,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "DELETE is a bad method",
|
||||
issuer: issuer,
|
||||
idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
method: http.MethodDelete,
|
||||
path: "/some/path",
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
@ -805,18 +806,6 @@ func TestAuthorizationEndpoint(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// Declare a separate type from the production code to ensure that the state param's contents was serialized
|
||||
// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of
|
||||
// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality
|
||||
// assertions about the redirect URL in this test.
|
||||
type expectedUpstreamStateParamFormat struct {
|
||||
P string `json:"p"`
|
||||
N string `json:"n"`
|
||||
C string `json:"c"`
|
||||
K string `json:"k"`
|
||||
V string `json:"v"`
|
||||
}
|
||||
|
||||
type errorReturningEncoder struct {
|
||||
securecookie.Codec
|
||||
}
|
||||
@ -850,13 +839,13 @@ func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL
|
||||
|
||||
expectedQueryStateParam := expectedLocationURL.Query().Get("state")
|
||||
require.NotEmpty(t, expectedQueryStateParam)
|
||||
var expectedDecodedStateParam expectedUpstreamStateParamFormat
|
||||
var expectedDecodedStateParam testutil.ExpectedUpstreamStateParamFormat
|
||||
err = stateParamDecoder.Decode("s", expectedQueryStateParam, &expectedDecodedStateParam)
|
||||
require.NoError(t, err)
|
||||
|
||||
actualQueryStateParam := actualLocationURL.Query().Get("state")
|
||||
require.NotEmpty(t, actualQueryStateParam)
|
||||
var actualDecodedStateParam expectedUpstreamStateParamFormat
|
||||
var actualDecodedStateParam testutil.ExpectedUpstreamStateParamFormat
|
||||
err = stateParamDecoder.Decode("s", actualQueryStateParam, &actualDecodedStateParam)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -884,9 +873,3 @@ func requireEqualURLs(t *testing.T, actualURL string, expectedURL string, ignore
|
||||
}
|
||||
require.Equal(t, expectedLocationQuery, actualLocationQuery)
|
||||
}
|
||||
|
||||
func newIDPListGetter(upstreamOIDCIdentityProviders ...provider.UpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider {
|
||||
idpProvider := provider.NewDynamicUpstreamIDPProvider()
|
||||
idpProvider.SetIDPList(upstreamOIDCIdentityProviders)
|
||||
return idpProvider
|
||||
}
|
||||
|
@ -6,16 +6,43 @@ package callback
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path"
|
||||
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/oidc"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
)
|
||||
|
||||
func NewHandler() http.Handler {
|
||||
func NewHandler(
|
||||
idpListGetter oidc.IDPListGetter,
|
||||
) http.Handler {
|
||||
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != http.MethodGet {
|
||||
return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method)
|
||||
}
|
||||
|
||||
return nil
|
||||
if r.FormValue("code") == "" {
|
||||
return httperr.New(http.StatusBadRequest, "code param not found")
|
||||
}
|
||||
|
||||
if r.FormValue("state") == "" {
|
||||
return httperr.New(http.StatusBadRequest, "state param not found")
|
||||
}
|
||||
|
||||
if findUpstreamIDPConfig(r, idpListGetter) == nil {
|
||||
return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found")
|
||||
}
|
||||
|
||||
return httperr.New(http.StatusBadRequest, "state param not valid")
|
||||
})
|
||||
}
|
||||
|
||||
func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProvider {
|
||||
_, lastPathComponent := path.Split(r.URL.Path)
|
||||
for _, p := range idpListGetter.GetIDPList() {
|
||||
if p.Name == lastPathComponent {
|
||||
return &p
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -4,18 +4,76 @@
|
||||
package callback
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
)
|
||||
|
||||
const (
|
||||
happyUpstreamIDPName = "upstream-idp-name"
|
||||
)
|
||||
|
||||
func TestCallbackEndpoint(t *testing.T) {
|
||||
upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth")
|
||||
require.NoError(t, err)
|
||||
otherUpstreamAuthURL, err := url.Parse("https://some-other-upstream-idp:8443/auth")
|
||||
require.NoError(t, err)
|
||||
|
||||
upstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{
|
||||
Name: happyUpstreamIDPName,
|
||||
ClientID: "some-client-id",
|
||||
AuthorizationURL: *upstreamAuthURL,
|
||||
Scopes: []string{"scope1", "scope2"},
|
||||
}
|
||||
|
||||
otherUpstreamOIDCIdentityProvider := provider.UpstreamOIDCIdentityProvider{
|
||||
Name: "other-upstream-idp-name",
|
||||
ClientID: "other-some-client-id",
|
||||
AuthorizationURL: *otherUpstreamAuthURL,
|
||||
Scopes: []string{"other-scope1", "other-scope2"},
|
||||
}
|
||||
|
||||
var stateEncoderHashKey = []byte("fake-hash-secret")
|
||||
var stateEncoderBlockKey = []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES
|
||||
var cookieEncoderHashKey = []byte("fake-hash-secret2")
|
||||
var cookieEncoderBlockKey = []byte("0123456789ABCDE2") // block encryption requires 16/24/32 bytes for AES
|
||||
require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey)
|
||||
require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey)
|
||||
|
||||
var happyStateEncoder = securecookie.New(stateEncoderHashKey, stateEncoderBlockKey)
|
||||
happyStateEncoder.SetSerializer(securecookie.JSONEncoder{})
|
||||
var happyCookieEncoder = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey)
|
||||
happyCookieEncoder.SetSerializer(securecookie.JSONEncoder{})
|
||||
|
||||
//happyCSRF := "test-csrf"
|
||||
//happyPKCE := "test-pkce"
|
||||
//happyNonce := "test-nonce"
|
||||
//
|
||||
//happyEncodedState, err := happyStateEncoder.Encode("s",
|
||||
// testutil.ExpectedUpstreamStateParamFormat{
|
||||
// P: "todo query goes here",
|
||||
// N: happyNonce,
|
||||
// C: happyCSRF,
|
||||
// K: happyPKCE,
|
||||
// V: "1",
|
||||
// },
|
||||
//)
|
||||
//require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
method string
|
||||
method string
|
||||
path string
|
||||
idpListGetter provider.DynamicUpstreamIDPProvider
|
||||
|
||||
wantStatus int
|
||||
wantBody string
|
||||
@ -27,24 +85,67 @@ func TestCallbackEndpoint(t *testing.T) {
|
||||
{
|
||||
name: "PUT method is invalid",
|
||||
method: http.MethodPut,
|
||||
path: newRequestPath().String(),
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
wantBody: "Method Not Allowed: PUT (try GET)\n",
|
||||
},
|
||||
// TODO: POST/PATCH/DELETE is invalid
|
||||
// TODO: request has body? maybe we don't need to do anything...
|
||||
// TODO: code does not exist
|
||||
// TODO: we got called twice with the same state and cookie...is this bad? might be ok if the client's first roundtrip failed
|
||||
// TODO: we got called twice with the same state and cookie and the UpstreamOIDCProvider CRD has been deleted
|
||||
// TODO: state does not exist
|
||||
// TODO: invalid signature on state
|
||||
// TODO: state is expired (the expiration is encoded in the state itself)
|
||||
// TODO: state csrf value does not match csrf cookie
|
||||
// TODO: cookie does not exist
|
||||
// TODO: invalid signature on cookie
|
||||
// TODO: state version does not match what we want
|
||||
{
|
||||
name: "POST method is invalid",
|
||||
method: http.MethodPost,
|
||||
path: newRequestPath().String(),
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
wantBody: "Method Not Allowed: POST (try GET)\n",
|
||||
},
|
||||
{
|
||||
name: "PATCH method is invalid",
|
||||
method: http.MethodPatch,
|
||||
path: newRequestPath().String(),
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
wantBody: "Method Not Allowed: PATCH (try GET)\n",
|
||||
},
|
||||
{
|
||||
name: "DELETE method is invalid",
|
||||
method: http.MethodDelete,
|
||||
path: newRequestPath().String(),
|
||||
wantStatus: http.StatusMethodNotAllowed,
|
||||
wantBody: "Method Not Allowed: DELETE (try GET)\n",
|
||||
},
|
||||
{
|
||||
name: "code param was not included on request",
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithoutCode().String(),
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: "Bad Request: code param not found\n",
|
||||
},
|
||||
{
|
||||
name: "state param was not included on request",
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithoutState().String(),
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: "Bad Request: state param not found\n",
|
||||
},
|
||||
{
|
||||
name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason",
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().WithState("this-will-not-decode").String(),
|
||||
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
|
||||
wantStatus: http.StatusBadRequest,
|
||||
wantBody: "Bad Request: state param not valid\n",
|
||||
},
|
||||
{
|
||||
name: "the UpstreamOIDCProvider CRD has been deleted",
|
||||
method: http.MethodGet,
|
||||
path: newRequestPath().String(),
|
||||
idpListGetter: testutil.NewIDPListGetter(otherUpstreamOIDCIdentityProvider),
|
||||
wantStatus: http.StatusUnprocessableEntity,
|
||||
wantBody: "Unprocessable Entity: upstream provider not found\n",
|
||||
},
|
||||
// TODO: csrf cookie does not exist on request
|
||||
// TODO: csrf cookie value cannot be decoded (e.g. invalid signture or any other decoding problem)
|
||||
// TODO: csrf value from inside state param does not match csrf cookie value
|
||||
// TODO: state's internal version does not match what we want
|
||||
|
||||
// Upstream exchange
|
||||
// TODO: we can't figure out what the upstream token endpoint is (do we get this UpstreamOIDCProvider name from the path?)
|
||||
// TODO: network call to upstream token endpoint fails
|
||||
// TODO: the upstream token endpoint returns an error
|
||||
|
||||
@ -61,14 +162,15 @@ func TestCallbackEndpoint(t *testing.T) {
|
||||
// TODO: here (e.g., id jwt cannot be verified, nonce is wrong, we didn't get refresh token, we didn't get access token, we didn't get id token, access token expires too quickly)
|
||||
|
||||
// Downstream redirect
|
||||
// TODO: we grant the openid scope if it was requested, similar to what we did in auth_handler.go
|
||||
// TODO: cannot generate auth code
|
||||
// TODO: cannot persist downstream state
|
||||
}
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
subject := NewHandler()
|
||||
req := httptest.NewRequest(test.method, "/path-is-not-yet-tested", nil /* body not yet tested */)
|
||||
subject := NewHandler(test.idpListGetter)
|
||||
req := httptest.NewRequest(test.method, test.path, nil)
|
||||
rsp := httptest.NewRecorder()
|
||||
subject.ServeHTTP(rsp, req)
|
||||
|
||||
@ -77,3 +179,55 @@ func TestCallbackEndpoint(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type requestPath struct {
|
||||
upstreamIDPName, code, state *string
|
||||
}
|
||||
|
||||
func newRequestPath() *requestPath {
|
||||
n := happyUpstreamIDPName
|
||||
c := "1234"
|
||||
s := "4321"
|
||||
return &requestPath{
|
||||
upstreamIDPName: &n,
|
||||
code: &c,
|
||||
state: &s,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *requestPath) WithUpstreamIDPName(name string) *requestPath {
|
||||
r.upstreamIDPName = &name
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *requestPath) WithCode(code string) *requestPath {
|
||||
r.code = &code
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *requestPath) WithoutCode() *requestPath {
|
||||
r.code = nil
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *requestPath) WithState(state string) *requestPath {
|
||||
r.state = &state
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *requestPath) WithoutState() *requestPath {
|
||||
r.state = nil
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *requestPath) String() string {
|
||||
path := fmt.Sprintf("/downstream-provider-name/callback/%s?", *r.upstreamIDPName)
|
||||
params := url.Values{}
|
||||
if r.code != nil {
|
||||
params.Add("code", *r.code)
|
||||
}
|
||||
if r.state != nil {
|
||||
params.Add("state", *r.state)
|
||||
}
|
||||
return path + params.Encode()
|
||||
}
|
||||
|
@ -7,6 +7,8 @@ package oidc
|
||||
import (
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/compose"
|
||||
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -49,3 +51,7 @@ func FositeOauth2Helper(oauthStore interface{}, hmacSecretOfLengthAtLeast32 []by
|
||||
compose.OAuth2PKCEFactory,
|
||||
)
|
||||
}
|
||||
|
||||
type IDPListGetter interface {
|
||||
GetIDPList() []provider.UpstreamOIDCIdentityProvider
|
||||
}
|
||||
|
@ -30,14 +30,14 @@ type Manager struct {
|
||||
providerHandlers map[string]http.Handler // map of all routes for all providers
|
||||
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
|
||||
dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data
|
||||
idpListGetter auth.IDPListGetter // in-memory cache of upstream IDPs
|
||||
idpListGetter oidc.IDPListGetter // in-memory cache of upstream IDPs
|
||||
}
|
||||
|
||||
// NewManager returns an empty Manager.
|
||||
// nextHandler will be invoked for any requests that could not be handled by this manager's providers.
|
||||
// dynamicJWKSProvider will be used as an in-memory cache for per-issuer JWKS data.
|
||||
// idpListGetter will be used as an in-memory cache of currently configured upstream IDPs.
|
||||
func NewManager(nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter auth.IDPListGetter) *Manager {
|
||||
func NewManager(nextHandler http.Handler, dynamicJWKSProvider jwks.DynamicJWKSProvider, idpListGetter oidc.IDPListGetter) *Manager {
|
||||
return &Manager{
|
||||
providerHandlers: make(map[string]http.Handler),
|
||||
nextHandler: nextHandler,
|
||||
|
26
internal/testutil/oidc.go
Normal file
26
internal/testutil/oidc.go
Normal file
@ -0,0 +1,26 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package testutil
|
||||
|
||||
import "go.pinniped.dev/internal/oidc/provider"
|
||||
|
||||
// Test helpers for the OIDC package.
|
||||
|
||||
func NewIDPListGetter(upstreamOIDCIdentityProviders ...provider.UpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider {
|
||||
idpProvider := provider.NewDynamicUpstreamIDPProvider()
|
||||
idpProvider.SetIDPList(upstreamOIDCIdentityProviders)
|
||||
return idpProvider
|
||||
}
|
||||
|
||||
// Declare a separate type from the production code to ensure that the state param's contents was serialized
|
||||
// in the format that we expect, with the json keys that we expect, etc. This also ensure that the order of
|
||||
// the serialized fields is the same, which doesn't really matter expect that we can make simpler equality
|
||||
// assertions about the redirect URL in this test.
|
||||
type ExpectedUpstreamStateParamFormat struct {
|
||||
P string `json:"p"`
|
||||
N string `json:"n"`
|
||||
C string `json:"c"`
|
||||
K string `json:"k"`
|
||||
V string `json:"v"`
|
||||
}
|
Loading…
Reference in New Issue
Block a user