callback_handler.go: Prepend iss to sub when making default username

- Also handle several more error cases
- Move RequireTimeInDelta to shared testutils package so other tests
  can also use it
- Move all of the oidc test helpers into a new oidc/oidctestutils
  package to break a circular import dependency. The shared testutil
  package can't depend on any of our other packages or else we
  end up with circular dependencies.
- Lots more assertions about what was stored at the end of the
  request to build confidence that we are going to pass all of the
  right settings over to the token endpoint through the storage, and
  also to avoid accidental regressions in that area in the future

Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
Andrew Keesler 2020-11-19 17:57:07 -08:00 committed by Ryan Richard
parent b49d37ca54
commit b25696a1fb
7 changed files with 345 additions and 202 deletions

View File

@ -20,10 +20,10 @@ import (
"go.pinniped.dev/internal/here"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/oidctestutil"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidcclient/nonce"
"go.pinniped.dev/internal/oidcclient/pkce"
"go.pinniped.dev/internal/testutil"
)
func TestAuthorizationEndpoint(t *testing.T) {
@ -114,7 +114,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth")
require.NoError(t, err)
upstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{
upstreamOIDCIdentityProvider := oidctestutil.TestUpstreamOIDCIdentityProvider{
Name: "some-idp",
ClientID: "some-client-id",
AuthorizationURL: *upstreamAuthURL,
@ -210,7 +210,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
csrf = csrfValueOverride
}
encoded, err := happyStateEncoder.Encode("s",
testutil.ExpectedUpstreamStateParamFormat{
oidctestutil.ExpectedUpstreamStateParamFormat{
P: encodeQuery(modifiedHappyGetRequestQueryMap(queryOverrides)),
N: happyNonce,
C: csrf,
@ -270,7 +270,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "happy path using GET without a CSRF cookie",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -288,7 +288,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "happy path using GET with a CSRF cookie",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -306,7 +306,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "happy path using POST",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -326,7 +326,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "happy path when downstream redirect uri matches what is configured for client except for the port number",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -348,7 +348,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "downstream redirect uri does not match what is configured for client",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -365,7 +365,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "downstream client does not exist",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -380,7 +380,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "response type is unsupported",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -396,7 +396,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "downstream scopes do not match what is configured for client",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -412,7 +412,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "missing response type in request",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -428,7 +428,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "missing client id in request",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -443,7 +443,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "missing PKCE code_challenge in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -459,7 +459,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "invalid value for PKCE code_challenge_method in request", // https://tools.ietf.org/html/rfc7636#section-4.3
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -475,7 +475,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "when PKCE code_challenge_method in request is `plain`", // https://tools.ietf.org/html/rfc7636#section-4.3
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -491,7 +491,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "missing PKCE code_challenge_method in request", // See https://tools.ietf.org/html/rfc7636#section-4.4.1
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -509,7 +509,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
// through that part of the fosite library.
name: "prompt param is not allowed to have none and another legal value at the same time",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -525,7 +525,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "OIDC validations are skipped when the openid scope was not requested",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -546,7 +546,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "state does not have enough entropy",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -562,7 +562,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "error while encoding upstream state param",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -577,7 +577,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "error while encoding CSRF cookie value for new cookie",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -592,7 +592,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "error while generating CSRF token",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: func() (csrftoken.CSRFToken, error) { return "", fmt.Errorf("some csrf generator error") },
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -607,7 +607,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "error while generating nonce",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: func() (nonce.Nonce, error) { return "", fmt.Errorf("some nonce generator error") },
@ -622,7 +622,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "error while generating PKCE",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some PKCE generator error") },
generateNonce: happyNonceGenerator,
@ -637,7 +637,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "error while decoding CSRF cookie",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
generateCSRF: happyCSRFGenerator,
generatePKCE: happyPKCEGenerator,
generateNonce: happyNonceGenerator,
@ -653,7 +653,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "no upstream providers are configured",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(), // empty
idpListGetter: oidctestutil.NewIDPListGetter(), // empty
method: http.MethodGet,
path: happyGetRequestPath,
wantStatus: http.StatusUnprocessableEntity,
@ -663,7 +663,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "too many upstream providers are configured",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider), // more than one not allowed
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider, &upstreamOIDCIdentityProvider), // more than one not allowed
method: http.MethodGet,
path: happyGetRequestPath,
wantStatus: http.StatusUnprocessableEntity,
@ -673,7 +673,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "PUT is a bad method",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
method: http.MethodPut,
path: "/some/path",
wantStatus: http.StatusMethodNotAllowed,
@ -683,7 +683,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "PATCH is a bad method",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
method: http.MethodPatch,
path: "/some/path",
wantStatus: http.StatusMethodNotAllowed,
@ -693,7 +693,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
{
name: "DELETE is a bad method",
issuer: downstreamIssuer,
idpListGetter: testutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
idpListGetter: oidctestutil.NewIDPListGetter(&upstreamOIDCIdentityProvider),
method: http.MethodDelete,
path: "/some/path",
wantStatus: http.StatusMethodNotAllowed,
@ -772,7 +772,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
runOneTestCase(t, test, subject)
// Call the setter to change the upstream IDP settings.
newProviderSettings := testutil.TestUpstreamOIDCIdentityProvider{
newProviderSettings := oidctestutil.TestUpstreamOIDCIdentityProvider{
Name: "some-other-idp",
ClientID: "some-other-client-id",
AuthorizationURL: *upstreamAuthURL,
@ -840,13 +840,13 @@ func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL
expectedQueryStateParam := expectedLocationURL.Query().Get("state")
require.NotEmpty(t, expectedQueryStateParam)
var expectedDecodedStateParam testutil.ExpectedUpstreamStateParamFormat
var expectedDecodedStateParam oidctestutil.ExpectedUpstreamStateParamFormat
err = stateParamDecoder.Decode("s", expectedQueryStateParam, &expectedDecodedStateParam)
require.NoError(t, err)
actualQueryStateParam := actualLocationURL.Query().Get("state")
require.NotEmpty(t, actualQueryStateParam)
var actualDecodedStateParam testutil.ExpectedUpstreamStateParamFormat
var actualDecodedStateParam oidctestutil.ExpectedUpstreamStateParamFormat
err = stateParamDecoder.Decode("s", actualQueryStateParam, &actualDecodedStateParam)
require.NoError(t, err)

View File

@ -5,6 +5,7 @@
package callback
import (
"fmt"
"net/http"
"net/url"
"path"
@ -22,13 +23,22 @@ import (
)
const (
// The name of the issuer claim specified in the OIDC spec.
idTokenIssuerClaim = "iss"
// The name of the subject claim specified in the OIDC spec.
idTokenSubjectClaim = "sub"
// defaultUpstreamUsernameClaim is what we will use to extract the username from an upstream OIDC
// ID token if the upstream OIDC IDP did not tell us to use another claim.
defaultUpstreamUsernameClaim = "sub"
defaultUpstreamUsernameClaim = idTokenSubjectClaim
// downstreamGroupsClaim is what we will use to encode the groups in the downstream OIDC ID token
// information.
downstreamGroupsClaim = "groups"
// The lifetime of an issued downstream ID token.
downstreamIDTokenLifetime = time.Minute * 5
)
func NewHandler(
@ -90,7 +100,8 @@ func NewHandler(
openIDSession := makeDownstreamSession(downstreamIssuer, downstreamAuthParams.Get("client_id"), username, groups)
authorizeResponder, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, openIDSession)
if err != nil {
panic(err) // TODO
plog.WarningErr("error while generating and saving authcode", err, "upstreamName", upstreamIDPConfig.GetName())
return httperr.Wrap(http.StatusInternalServerError, "error while generating and saving authcode", err)
}
oauthHelper.WriteAuthorizeResponse(w, authorizeRequester, authorizeResponder)
@ -194,9 +205,30 @@ func getUsernameFromUpstreamIDToken(
idTokenClaims map[string]interface{},
) (string, error) {
usernameClaim := upstreamIDPConfig.GetUsernameClaim()
user := ""
if usernameClaim == "" {
// TODO: if we use the default "sub" claim, maybe we should create the username with the issuer
// since the spec says the "sub" claim is only unique per issuer.
// The spec says the "sub" claim is only unique per issuer, so by default when there is
// no specific username claim configured we will prepend the issuer string to make it globally unique.
upstreamIssuer := idTokenClaims[idTokenIssuerClaim]
if upstreamIssuer == "" {
plog.Warning(
"issuer claim in upstream ID token missing",
"upstreamName", upstreamIDPConfig.GetName(),
"issClaim", upstreamIssuer,
)
return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token missing")
}
upstreamIssuerAsString, ok := upstreamIssuer.(string)
if !ok {
plog.Warning(
"issuer claim in upstream ID token has invalid format",
"upstreamName", upstreamIDPConfig.GetName(),
"issClaim", upstreamIssuer,
)
return "", httperr.New(http.StatusUnprocessableEntity, "issuer claim in upstream ID token has invalid format")
}
user = fmt.Sprintf("%s?%s=", upstreamIssuerAsString, idTokenSubjectClaim)
usernameClaim = defaultUpstreamUsernameClaim
}
@ -222,7 +254,7 @@ func getUsernameFromUpstreamIDToken(
return "", httperr.New(http.StatusUnprocessableEntity, "username claim in upstream ID token has invalid format")
}
return username, nil
return fmt.Sprintf("%s%s", user, username), nil
}
func getGroupsFromUpstreamIDToken(
@ -266,10 +298,10 @@ func makeDownstreamSession(issuer, clientID, username string, groups []string) *
Issuer: issuer,
Subject: username,
Audience: []string{clientID},
ExpiresAt: now.Add(time.Minute * 30), // TODO use the right value here
IssuedAt: now, // TODO test this
RequestedAt: now, // TODO test this
AuthTime: now, // TODO test this
ExpiresAt: now.Add(downstreamIDTokenLifetime),
IssuedAt: now,
RequestedAt: now,
AuthTime: now,
},
}
if groups != nil {

View File

@ -13,6 +13,7 @@ import (
"regexp"
"strings"
"testing"
"time"
"github.com/gorilla/securecookie"
"github.com/ory/fosite"
@ -21,6 +22,7 @@ import (
"github.com/stretchr/testify/require"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/oidctestutil"
"go.pinniped.dev/internal/oidcclient"
"go.pinniped.dev/internal/oidcclient/nonce"
"go.pinniped.dev/internal/oidcclient/pkce"
@ -30,26 +32,46 @@ import (
const (
happyUpstreamIDPName = "upstream-idp-name"
upstreamIssuer = "https://my-upstream-issuer.com"
upstreamSubject = "abc123-some-guid"
upstreamUsername = "test-pinniped-username"
upstreamUsernameClaim = "the-user-claim"
upstreamGroupsClaim = "the-groups-claim"
happyDownstreamState = "some-downstream-state"
happyCSRF = "test-csrf"
happyPKCE = "test-pkce"
happyNonce = "test-nonce"
happyStateVersion = "1"
downstreamIssuer = "https://my-downstream-issuer.com/path"
happyUpstreamAuthcode = "upstream-auth-code"
downstreamRedirectURI = "http://127.0.0.1/callback"
downstreamClientID = "pinniped-cli"
timeComparisonFudgeFactor = time.Second * 15
)
var (
upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"}
upstreamGroupMembership = []string{"test-pinniped-group-0", "test-pinniped-group-1"}
happyDownstreamScopesRequested = []string{"openid", "profile", "email"}
happyOriginalRequestParamsQuery = url.Values{
"response_type": []string{"code"},
"scope": []string{strings.Join(happyDownstreamScopesRequested, " ")},
"client_id": []string{downstreamClientID},
"state": []string{happyDownstreamState},
"nonce": []string{"some-nonce-value"},
"code_challenge": []string{"some-challenge"},
"code_challenge_method": []string{"S256"},
"redirect_uri": []string{downstreamRedirectURI},
}
happyOriginalRequestParams = happyOriginalRequestParamsQuery.Encode()
)
func TestCallbackEndpoint(t *testing.T) {
const (
downstreamIssuer = "https://my-downstream-issuer.com/path"
downstreamRedirectURI = "http://127.0.0.1/callback"
happyUpstreamAuthcode = "upstream-auth-code"
downstreamClientID = "pinniped-cli"
)
otherUpstreamOIDCIdentityProvider := testutil.TestUpstreamOIDCIdentityProvider{
otherUpstreamOIDCIdentityProvider := oidctestutil.TestUpstreamOIDCIdentityProvider{
Name: "other-upstream-idp-name",
ClientID: "other-some-client-id",
Scopes: []string{"other-scope1", "other-scope2"},
@ -67,95 +89,13 @@ func TestCallbackEndpoint(t *testing.T) {
var happyCookieCodec = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey)
happyCookieCodec.SetSerializer(securecookie.JSONEncoder{})
happyDownstreamState := "some-downstream-state"
happyOriginalRequestParamsQuery := url.Values{
"response_type": []string{"code"},
"scope": []string{"openid profile email"},
"client_id": []string{downstreamClientID},
"state": []string{happyDownstreamState},
"nonce": []string{"some-nonce-value"},
"code_challenge": []string{"some-challenge"},
"code_challenge_method": []string{"S256"},
"redirect_uri": []string{downstreamRedirectURI},
}
happyOriginalRequestParams := happyOriginalRequestParamsQuery.Encode()
happyCSRF := "test-csrf"
happyPKCE := "test-pkce"
happyNonce := "test-nonce"
happyStateVersion := "1"
happyState, err := happyStateCodec.Encode("s",
testutil.ExpectedUpstreamStateParamFormat{
P: happyOriginalRequestParams,
N: happyNonce,
C: happyCSRF,
K: happyPKCE,
V: happyStateVersion,
},
)
require.NoError(t, err)
wrongCSRFValueState, err := happyStateCodec.Encode("s",
testutil.ExpectedUpstreamStateParamFormat{
P: happyOriginalRequestParams,
N: happyNonce,
C: "wrong-csrf-value",
K: happyPKCE,
V: happyStateVersion,
},
)
require.NoError(t, err)
wrongVersionState, err := happyStateCodec.Encode("s",
testutil.ExpectedUpstreamStateParamFormat{
P: happyOriginalRequestParams,
N: happyNonce,
C: happyCSRF,
K: happyPKCE,
V: "wrong-state-version",
},
)
require.NoError(t, err)
wrongDownstreamAuthParamsState, err := happyStateCodec.Encode("s",
testutil.ExpectedUpstreamStateParamFormat{
P: "these-is-not-a-valid-url-query-%z",
N: happyNonce,
C: happyCSRF,
K: happyPKCE,
V: happyStateVersion,
},
)
require.NoError(t, err)
missingClientIDState, err := happyStateCodec.Encode("s",
testutil.ExpectedUpstreamStateParamFormat{
P: shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"client_id": ""}).Encode(),
N: happyNonce,
C: happyCSRF,
K: happyPKCE,
V: happyStateVersion,
},
)
require.NoError(t, err)
noOpenidScopeState, err := happyStateCodec.Encode("s",
testutil.ExpectedUpstreamStateParamFormat{
P: shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"scope": "profile email"}).Encode(),
N: happyNonce,
C: happyCSRF,
K: happyPKCE,
V: happyStateVersion,
},
)
require.NoError(t, err)
happyState := happyUpstreamStateParam().Build(t, happyStateCodec)
encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyCSRF)
require.NoError(t, err)
happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue
happyExchangeAndValidateTokensArgs := &testutil.ExchangeAuthcodeAndValidateTokenArgs{
happyExchangeAndValidateTokensArgs := &oidctestutil.ExchangeAuthcodeAndValidateTokenArgs{
Authcode: happyUpstreamAuthcode,
PKCECodeVerifier: pkce.Code(happyPKCE),
ExpectedIDTokenNonce: nonce.Nonce(happyNonce),
@ -167,25 +107,26 @@ func TestCallbackEndpoint(t *testing.T) {
tests := []struct {
name string
idp testutil.TestUpstreamOIDCIdentityProvider
idp oidctestutil.TestUpstreamOIDCIdentityProvider
method string
path string
csrfCookie string
wantStatus int
wantBody string
wantRedirectLocationRegexp string
wantGrantedOpenidScope bool
wantDownstreamIDTokenSubject string
wantDownstreamIDTokenGroups []string
wantStatus int
wantBody string
wantRedirectLocationRegexp string
wantGrantedOpenidScope bool
wantDownstreamIDTokenSubject string
wantDownstreamIDTokenGroups []string
wantDownstreamRequestedScopes []string
wantExchangeAndValidateTokensCall *testutil.ExchangeAuthcodeAndValidateTokenArgs
wantExchangeAndValidateTokensCall *oidctestutil.ExchangeAuthcodeAndValidateTokenArgs
}{
{
name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: happyRedirectLocationRegexp,
@ -193,22 +134,39 @@ func TestCallbackEndpoint(t *testing.T) {
wantBody: "",
wantDownstreamIDTokenSubject: upstreamUsername,
wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream IDP provides no username or group claim, so we use default username claim and skip groups",
name: "upstream IDP provides no username or group claim configuration, so we use default username claim and skip groups",
idp: happyUpstream().WithoutUsernameClaim().WithoutGroupsClaim().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: happyRedirectLocationRegexp,
wantGrantedOpenidScope: true,
wantBody: "",
wantDownstreamIDTokenSubject: upstreamIssuer + "?sub=" + upstreamSubject,
wantDownstreamIDTokenGroups: nil,
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream IDP provides username claim configuration as `sub`, so the downstream token subject should be exactly what they asked for",
idp: happyUpstream().WithUsernameClaim("sub").Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: happyRedirectLocationRegexp,
wantGrantedOpenidScope: true,
wantBody: "",
wantDownstreamIDTokenSubject: upstreamSubject,
wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantDownstreamRequestedScopes: happyDownstreamScopesRequested,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
// TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes)
// Pre-upstream-exchange verification
{
@ -264,42 +222,70 @@ func TestCallbackEndpoint(t *testing.T) {
wantStatus: http.StatusBadRequest,
wantBody: "Bad Request: error reading state\n",
},
{
// This shouldn't happen in practice because the authorize endpoint should have already run the same
// validations, but we would like to test the error handling in this endpoint anyway.
name: "state param contains authorization request params which fail validation",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(
happyUpstreamStateParam().
WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"prompt": "none login"}).Encode()).
Build(t, happyStateCodec),
).String(),
csrfCookie: happyCSRFCookie,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
wantStatus: http.StatusInternalServerError,
wantBody: "Internal Server Error: error while generating and saving authcode\n",
},
{
name: "state's internal version does not match what we want",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(wrongVersionState).String(),
path: newRequestPath().WithState(happyUpstreamStateParam().WithStateVersion("wrong-state-version").Build(t, happyStateCodec)).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: state format version is invalid\n",
},
{
name: "state's downstream auth params element is invalid",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(wrongDownstreamAuthParamsState).String(),
name: "state's downstream auth params element is invalid",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyUpstreamStateParam().
WithAuthorizeRequestParams("the following is an invalid url encoding token, and therefore this is an invalid param: %z").
Build(t, happyStateCodec)).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest,
wantBody: "Bad Request: error reading state downstream auth params\n",
},
{
name: "state's downstream auth params are missing required value (e.g., client_id)",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(missingClientIDState).String(),
name: "state's downstream auth params are missing required value (e.g., client_id)",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(
happyUpstreamStateParam().
WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"client_id": ""}).Encode()).
Build(t, happyStateCodec),
).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest,
wantBody: "Bad Request: error using state downstream auth params\n",
},
{
name: "state's downstream auth params does not contain openid scope",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(noOpenidScopeState).WithCode(happyUpstreamAuthcode).String(),
name: "state's downstream auth params does not contain openid scope",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().
WithState(
happyUpstreamStateParam().
WithAuthorizeRequestParams(shallowCopyAndModifyQuery(happyOriginalRequestParamsQuery, map[string]string{"scope": "profile email"}).Encode()).
Build(t, happyStateCodec),
).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&scope=&state=` + happyDownstreamState,
wantDownstreamIDTokenSubject: upstreamUsername,
wantDownstreamRequestedScopes: []string{"profile", "email"},
wantDownstreamIDTokenGroups: upstreamGroupMembership,
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
@ -333,7 +319,7 @@ func TestCallbackEndpoint(t *testing.T) {
name: "cookie csrf value does not match state csrf value",
idp: happyUpstream().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(wrongCSRFValueState).String(),
path: newRequestPath().WithState(happyUpstreamStateParam().WithCSRF("wrong-csrf-value").Build(t, happyStateCodec)).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusForbidden,
wantBody: "Forbidden: CSRF value does not match\n",
@ -344,7 +330,7 @@ func TestCallbackEndpoint(t *testing.T) {
name: "upstream auth code exchange fails",
idp: happyUpstream().WithoutUpstreamAuthcodeExchangeError(errors.New("some error")).Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadGateway,
wantBody: "Bad Gateway: error exchanging and validating upstream tokens\n",
@ -354,7 +340,7 @@ func TestCallbackEndpoint(t *testing.T) {
name: "upstream ID token does not contain requested username claim",
idp: happyUpstream().WithoutIDTokenClaim(upstreamUsernameClaim).Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: no username claim in upstream ID token\n",
@ -364,7 +350,7 @@ func TestCallbackEndpoint(t *testing.T) {
name: "upstream ID token does not contain requested groups claim",
idp: happyUpstream().WithoutIDTokenClaim(upstreamGroupsClaim).Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: no groups claim in upstream ID token\n",
@ -374,17 +360,37 @@ func TestCallbackEndpoint(t *testing.T) {
name: "upstream ID token contains username claim with weird format",
idp: happyUpstream().WithIDTokenClaim(upstreamUsernameClaim, 42).Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: username claim in upstream ID token has invalid format\n",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream ID token does not contain iss claim when using default username claim config",
idp: happyUpstream().WithIDTokenClaim("iss", "").WithoutUsernameClaim().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: issuer claim in upstream ID token missing\n",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream ID token has an non-string iss claim when using default username claim config",
idp: happyUpstream().WithIDTokenClaim("iss", 42).WithoutUsernameClaim().Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: issuer claim in upstream ID token has invalid format\n",
wantExchangeAndValidateTokensCall: happyExchangeAndValidateTokensArgs,
},
{
name: "upstream ID token contains groups claim with weird format",
idp: happyUpstream().WithIDTokenClaim(upstreamGroupsClaim, 42).Build(),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).WithCode(happyUpstreamAuthcode).String(),
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusUnprocessableEntity,
wantBody: "Unprocessable Entity: groups claim in upstream ID token has invalid format\n",
@ -393,6 +399,7 @@ func TestCallbackEndpoint(t *testing.T) {
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
// Configure fosite the same way that the production code would, except use in-memory storage.
// Inject this into our test subject at the last second so we get a fresh storage for every test.
@ -406,7 +413,7 @@ func TestCallbackEndpoint(t *testing.T) {
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
oauthHelper := oidc.FositeOauth2Helper(oauthStore, hmacSecret)
idpListGetter := testutil.NewIDPListGetter(&test.idp)
idpListGetter := oidctestutil.NewIDPListGetter(&test.idp)
subject := NewHandler(downstreamIssuer, idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec)
req := httptest.NewRequest(test.method, test.path, nil)
if test.csrfCookie != "" {
@ -433,7 +440,7 @@ func TestCallbackEndpoint(t *testing.T) {
require.Empty(t, rsp.Body.String())
}
if test.wantRedirectLocationRegexp != "" {
if test.wantRedirectLocationRegexp != "" { //nolint:nestif // don't mind have several sequential if statements in this test
// Assert that Location header matches regular expression.
require.Len(t, rsp.Header().Values("Location"), 1)
actualLocation := rsp.Header().Get("Location")
@ -459,20 +466,63 @@ func TestCallbackEndpoint(t *testing.T) {
storedSession, ok := storedAuthorizeRequest.GetSession().(*openid.DefaultSession)
require.True(t, ok)
// Check various fields of the stored data.
// Check which scopes were granted.
if test.wantGrantedOpenidScope {
require.Contains(t, storedRequest.GetGrantedScopes(), "openid")
} else {
require.NotContains(t, storedRequest.GetGrantedScopes(), "openid")
}
require.Equal(t, downstreamIssuer, storedSession.Claims.Issuer)
require.Equal(t, test.wantDownstreamIDTokenSubject, storedSession.Claims.Subject)
require.Equal(t, []string{downstreamClientID}, storedSession.Claims.Audience)
// Check all the other fields of the stored request.
require.NotEmpty(t, storedRequest.ID)
require.Equal(t, downstreamClientID, storedRequest.Client.GetID())
require.ElementsMatch(t, test.wantDownstreamRequestedScopes, storedRequest.RequestedScope)
require.Nil(t, storedRequest.RequestedAudience)
require.Empty(t, storedRequest.GrantedAudience)
require.Equal(t, url.Values{"redirect_uri": []string{downstreamRedirectURI}}, storedRequest.Form)
testutil.RequireTimeInDelta(t, time.Now(), storedRequest.RequestedAt, timeComparisonFudgeFactor)
// We're not using these fields yet, so confirm that we did not set them (for now).
require.Empty(t, storedSession.Subject)
require.Empty(t, storedSession.Username)
require.Empty(t, storedSession.Headers)
// The authcode that we are issuing should be good for 15 minutes, which is default for fosite.
testutil.RequireTimeInDelta(t, time.Now().Add(time.Minute*15), storedSession.ExpiresAt[fosite.AuthorizeCode], timeComparisonFudgeFactor)
require.Len(t, storedSession.ExpiresAt, 1)
// Now confirm the ID token claims.
actualClaims := storedSession.Claims
// Check the user's identity, which are put into the downstream ID token's subject and groups claims.
require.Equal(t, test.wantDownstreamIDTokenSubject, actualClaims.Subject)
if test.wantDownstreamIDTokenGroups != nil {
require.Equal(t, test.wantDownstreamIDTokenGroups, storedSession.Claims.Extra["groups"])
require.Len(t, actualClaims.Extra, 1)
require.Equal(t, test.wantDownstreamIDTokenGroups, actualClaims.Extra["groups"])
} else {
require.NotContains(t, storedSession.Claims.Extra, "groups")
require.Empty(t, actualClaims.Extra)
require.NotContains(t, actualClaims.Extra, "groups")
}
// Check the rest of the downstream ID token's claims.
require.Equal(t, downstreamIssuer, actualClaims.Issuer)
require.Equal(t, []string{downstreamClientID}, actualClaims.Audience)
testutil.RequireTimeInDelta(t, time.Now().Add(time.Minute*5), actualClaims.ExpiresAt, timeComparisonFudgeFactor)
testutil.RequireTimeInDelta(t, time.Now(), actualClaims.IssuedAt, timeComparisonFudgeFactor)
testutil.RequireTimeInDelta(t, time.Now(), actualClaims.RequestedAt, timeComparisonFudgeFactor)
testutil.RequireTimeInDelta(t, time.Now(), actualClaims.AuthTime, timeComparisonFudgeFactor)
// These are not needed yet.
require.Empty(t, actualClaims.JTI)
require.Empty(t, actualClaims.CodeHash)
require.Empty(t, actualClaims.AccessTokenHash)
require.Empty(t, actualClaims.AuthenticationContextClassReference)
require.Empty(t, actualClaims.AuthenticationMethodsReference)
// TODO we should put the downstream request's nonce into the ID token, but maybe the token endpoint is responsible for that?
require.Empty(t, actualClaims.Nonce)
// TODO add thorough tests about what should be stored for PKCES and IDSessions
} else {
require.Empty(t, rsp.Header().Values("Location"))
}
@ -486,7 +536,7 @@ type requestPath struct {
func newRequestPath() *requestPath {
n := happyUpstreamIDPName
c := "1234"
c := happyUpstreamAuthcode
s := "4321"
return &requestPath{
upstreamIDPName: &n,
@ -532,6 +582,49 @@ func (r *requestPath) String() string {
return path + params.Encode()
}
type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat
func happyUpstreamStateParam() *upstreamStateParamBuilder {
return &upstreamStateParamBuilder{
P: happyOriginalRequestParams,
N: happyNonce,
C: happyCSRF,
K: happyPKCE,
V: happyStateVersion,
}
}
func (b upstreamStateParamBuilder) Build(t *testing.T, stateEncoder *securecookie.SecureCookie) string {
state, err := stateEncoder.Encode("s", b)
require.NoError(t, err)
return state
}
func (b *upstreamStateParamBuilder) WithAuthorizeRequestParams(params string) *upstreamStateParamBuilder {
b.P = params
return b
}
func (b *upstreamStateParamBuilder) WithNonce(nonce string) *upstreamStateParamBuilder {
b.N = nonce
return b
}
func (b *upstreamStateParamBuilder) WithCSRF(csrf string) *upstreamStateParamBuilder {
b.C = csrf
return b
}
func (b *upstreamStateParamBuilder) WithPKCVE(pkce string) *upstreamStateParamBuilder {
b.K = pkce
return b
}
func (b *upstreamStateParamBuilder) WithStateVersion(version string) *upstreamStateParamBuilder {
b.V = version
return b
}
type upstreamOIDCIdentityProviderBuilder struct {
idToken map[string]interface{}
usernameClaim, groupsClaim string
@ -543,6 +636,7 @@ func happyUpstream() *upstreamOIDCIdentityProviderBuilder {
usernameClaim: upstreamUsernameClaim,
groupsClaim: upstreamGroupsClaim,
idToken: map[string]interface{}{
"iss": upstreamIssuer,
"sub": upstreamSubject,
upstreamUsernameClaim: upstreamUsername,
upstreamGroupsClaim: upstreamGroupMembership,
@ -551,6 +645,11 @@ func happyUpstream() *upstreamOIDCIdentityProviderBuilder {
}
}
func (u *upstreamOIDCIdentityProviderBuilder) WithUsernameClaim(claim string) *upstreamOIDCIdentityProviderBuilder {
u.usernameClaim = claim
return u
}
func (u *upstreamOIDCIdentityProviderBuilder) WithoutUsernameClaim() *upstreamOIDCIdentityProviderBuilder {
u.usernameClaim = ""
return u
@ -576,8 +675,8 @@ func (u *upstreamOIDCIdentityProviderBuilder) WithoutUpstreamAuthcodeExchangeErr
return u
}
func (u *upstreamOIDCIdentityProviderBuilder) Build() testutil.TestUpstreamOIDCIdentityProvider {
return testutil.TestUpstreamOIDCIdentityProvider{
func (u *upstreamOIDCIdentityProviderBuilder) Build() oidctestutil.TestUpstreamOIDCIdentityProvider {
return oidctestutil.TestUpstreamOIDCIdentityProvider{
Name: happyUpstreamIDPName,
ClientID: "some-client-id",
UsernameClaim: u.usernameClaim,
@ -592,12 +691,13 @@ func (u *upstreamOIDCIdentityProviderBuilder) Build() testutil.TestUpstreamOIDCI
func shallowCopyAndModifyQuery(query url.Values, modifications map[string]string) url.Values {
copied := url.Values{}
for key, value := range query {
if modification, ok := modifications[key]; ok {
if modification != "" {
copied[key] = []string{modification}
}
copied[key] = value
}
for key, value := range modifications {
if value == "" {
copied.Del(key)
} else {
copied[key] = value
copied[key] = []string{value}
}
}
return copied

View File

@ -1,7 +1,7 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package testutil
package oidctestutil
import (
"context"

View File

@ -12,8 +12,6 @@ import (
"strings"
"testing"
"go.pinniped.dev/internal/testutil"
"github.com/sclevine/spec"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"
@ -22,6 +20,7 @@ import (
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/discovery"
"go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/oidctestutil"
"go.pinniped.dev/internal/oidc/provider"
)
@ -109,7 +108,7 @@ func TestManager(t *testing.T) {
parsedUpstreamIDPAuthorizationURL, err := url.Parse(upstreamIDPAuthorizationURL)
r.NoError(err)
idpListGetter := testutil.NewIDPListGetter(&testutil.TestUpstreamOIDCIdentityProvider{
idpListGetter := oidctestutil.NewIDPListGetter(&oidctestutil.TestUpstreamOIDCIdentityProvider{
Name: "test-idp",
ClientID: "test-client-id",
AuthorizationURL: *parsedUpstreamIDPAuthorizationURL,

View File

@ -26,6 +26,7 @@ import (
"go.pinniped.dev/internal/oidcclient/nonce"
"go.pinniped.dev/internal/oidcclient/pkce"
"go.pinniped.dev/internal/oidcclient/state"
"go.pinniped.dev/internal/testutil"
)
// mockSessionCache exists to avoid an import cycle if we generate mocks into another package.
@ -481,7 +482,7 @@ func TestLogin(t *testing.T) {
require.NotNil(t, tok.AccessToken)
require.Equal(t, want.Token, tok.AccessToken.Token)
require.Equal(t, want.Type, tok.AccessToken.Type)
requireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second)
testutil.RequireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second)
} else {
assert.Nil(t, tok.AccessToken)
}
@ -489,7 +490,7 @@ func TestLogin(t *testing.T) {
if want := tt.wantToken.IDToken; want != nil {
require.NotNil(t, tok.IDToken)
require.Equal(t, want.Token, tok.IDToken.Token)
requireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second)
testutil.RequireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second)
} else {
assert.Nil(t, tok.IDToken)
}
@ -682,16 +683,3 @@ type mockDiscovery struct{ provider *oidc.Provider }
func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() }
func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() }
func requireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) {
require.InDeltaf(t,
float64(t1.UnixNano()),
float64(t2.UnixNano()),
float64(delta.Nanoseconds()),
"expected %s and %s to be < %s apart, but they are %s apart",
t1.Format(time.RFC3339Nano),
t2.Format(time.RFC3339Nano),
delta.String(),
t1.Sub(t2).String(),
)
}

View File

@ -0,0 +1,24 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package testutil
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
func RequireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) {
require.InDeltaf(t,
float64(t1.UnixNano()),
float64(t2.UnixNano()),
float64(delta.Nanoseconds()),
"expected %s and %s to be < %s apart, but they are %s apart",
t1.Format(time.RFC3339Nano),
t2.Format(time.RFC3339Nano),
delta.String(),
t1.Sub(t2).String(),
)
}