Add custom prefix to downstream access and refresh tokens and authcodes
This commit is contained in:
parent
13daf59217
commit
53348b8464
@ -1,14 +1,22 @@
|
||||
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
|
||||
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/compose"
|
||||
"github.com/ory/fosite/handler/oauth2"
|
||||
errorsx "github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
accessTokenPrefix = "pin_at_" // "Pinniped access token" abbreviated.
|
||||
refreshTokenPrefix = "pin_rt_" // "Pinniped refresh token" abbreviated.
|
||||
authcodePrefix = "pin_ac_" // "Pinniped authorization code" abbreviated.
|
||||
)
|
||||
|
||||
// dynamicOauth2HMACStrategy is an oauth2.CoreStrategy that can dynamically load an HMAC key to sign
|
||||
@ -19,6 +27,9 @@ import (
|
||||
// If we ever update FederationDomain's to hold their signing key, we might not need this type, since we
|
||||
// could have an invariant that routes to an FederationDomain's endpoints are only wired up if an
|
||||
// FederationDomain has a valid signing key.
|
||||
//
|
||||
// Tokens start with a custom prefix to make them identifiable as tokens when seen by a user
|
||||
// out of context, such as when accidentally committed to a GitHub repo.
|
||||
type dynamicOauth2HMACStrategy struct {
|
||||
fositeConfig *compose.Config
|
||||
keyFunc func() []byte
|
||||
@ -44,7 +55,11 @@ func (s *dynamicOauth2HMACStrategy) GenerateAccessToken(
|
||||
ctx context.Context,
|
||||
requester fosite.Requester,
|
||||
) (token string, signature string, err error) {
|
||||
return s.delegate().GenerateAccessToken(ctx, requester)
|
||||
token, sig, err := s.delegate().GenerateAccessToken(ctx, requester)
|
||||
if err == nil {
|
||||
token = accessTokenPrefix + token
|
||||
}
|
||||
return token, sig, err
|
||||
}
|
||||
|
||||
func (s *dynamicOauth2HMACStrategy) ValidateAccessToken(
|
||||
@ -52,7 +67,11 @@ func (s *dynamicOauth2HMACStrategy) ValidateAccessToken(
|
||||
requester fosite.Requester,
|
||||
token string,
|
||||
) (err error) {
|
||||
return s.delegate().ValidateAccessToken(ctx, requester, token)
|
||||
if !strings.HasPrefix(token, accessTokenPrefix) {
|
||||
return errorsx.WithStack(fosite.ErrInvalidTokenFormat.
|
||||
WithDebugf("Access token did not have prefix %q", accessTokenPrefix))
|
||||
}
|
||||
return s.delegate().ValidateAccessToken(ctx, requester, token[len(accessTokenPrefix):])
|
||||
}
|
||||
|
||||
func (s *dynamicOauth2HMACStrategy) RefreshTokenSignature(token string) string {
|
||||
@ -63,7 +82,11 @@ func (s *dynamicOauth2HMACStrategy) GenerateRefreshToken(
|
||||
ctx context.Context,
|
||||
requester fosite.Requester,
|
||||
) (token string, signature string, err error) {
|
||||
return s.delegate().GenerateRefreshToken(ctx, requester)
|
||||
token, sig, err := s.delegate().GenerateRefreshToken(ctx, requester)
|
||||
if err == nil {
|
||||
token = refreshTokenPrefix + token
|
||||
}
|
||||
return token, sig, err
|
||||
}
|
||||
|
||||
func (s *dynamicOauth2HMACStrategy) ValidateRefreshToken(
|
||||
@ -71,7 +94,11 @@ func (s *dynamicOauth2HMACStrategy) ValidateRefreshToken(
|
||||
requester fosite.Requester,
|
||||
token string,
|
||||
) (err error) {
|
||||
return s.delegate().ValidateRefreshToken(ctx, requester, token)
|
||||
if !strings.HasPrefix(token, refreshTokenPrefix) {
|
||||
return errorsx.WithStack(fosite.ErrInvalidTokenFormat.
|
||||
WithDebugf("Refresh token did not have prefix %q", refreshTokenPrefix))
|
||||
}
|
||||
return s.delegate().ValidateRefreshToken(ctx, requester, token[len(refreshTokenPrefix):])
|
||||
}
|
||||
|
||||
func (s *dynamicOauth2HMACStrategy) AuthorizeCodeSignature(token string) string {
|
||||
@ -82,7 +109,11 @@ func (s *dynamicOauth2HMACStrategy) GenerateAuthorizeCode(
|
||||
ctx context.Context,
|
||||
requester fosite.Requester,
|
||||
) (token string, signature string, err error) {
|
||||
return s.delegate().GenerateAuthorizeCode(ctx, requester)
|
||||
authcode, sig, err := s.delegate().GenerateAuthorizeCode(ctx, requester)
|
||||
if err == nil {
|
||||
authcode = authcodePrefix + authcode
|
||||
}
|
||||
return authcode, sig, err
|
||||
}
|
||||
|
||||
func (s *dynamicOauth2HMACStrategy) ValidateAuthorizeCode(
|
||||
@ -90,7 +121,11 @@ func (s *dynamicOauth2HMACStrategy) ValidateAuthorizeCode(
|
||||
requester fosite.Requester,
|
||||
token string,
|
||||
) (err error) {
|
||||
return s.delegate().ValidateAuthorizeCode(ctx, requester, token)
|
||||
if !strings.HasPrefix(token, authcodePrefix) {
|
||||
return errorsx.WithStack(fosite.ErrInvalidTokenFormat.
|
||||
WithDebugf("Authorization code did not have prefix %q", authcodePrefix))
|
||||
}
|
||||
return s.delegate().ValidateAuthorizeCode(ctx, requester, token[len(authcodePrefix):])
|
||||
}
|
||||
|
||||
func (s *dynamicOauth2HMACStrategy) delegate() *oauth2.HMACSHAStrategy {
|
||||
|
218
internal/oidc/dynamic_oauth2_hmac_strategy_test.go
Normal file
218
internal/oidc/dynamic_oauth2_hmac_strategy_test.go
Normal file
@ -0,0 +1,218 @@
|
||||
// Copyright 2022 the Pinniped contributors. All Rights Reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/fosite/compose"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDynamicOauth2HMACStrategy_Signatures(t *testing.T) {
|
||||
s := &dynamicOauth2HMACStrategy{
|
||||
fositeConfig: &compose.Config{}, // defaults are good enough for this unit test
|
||||
keyFunc: func() []byte { return []byte("12345678901234567890123456789012") },
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
signatureFunc func(token string) (signature string)
|
||||
wantSignature string
|
||||
}{
|
||||
{
|
||||
name: "access token signature is the part after the dot in the default HMAC strategy",
|
||||
token: "token.signature",
|
||||
signatureFunc: s.AccessTokenSignature,
|
||||
wantSignature: "signature",
|
||||
},
|
||||
{
|
||||
name: "refresh token signature is the part after the dot in the default HMAC strategy",
|
||||
token: "token.signature",
|
||||
signatureFunc: s.RefreshTokenSignature,
|
||||
wantSignature: "signature",
|
||||
},
|
||||
{
|
||||
name: "authcode signature is the part after the dot in the default HMAC strategy",
|
||||
token: "token.signature",
|
||||
signatureFunc: s.AuthorizeCodeSignature,
|
||||
wantSignature: "signature",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
require.Equal(t, tt.wantSignature, tt.signatureFunc(tt.token))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicOauth2HMACStrategy_Generate(t *testing.T) {
|
||||
s := &dynamicOauth2HMACStrategy{
|
||||
fositeConfig: &compose.Config{}, // defaults are good enough for this unit test
|
||||
keyFunc: func() []byte { return []byte("12345678901234567890123456789012") }, // 32 character secret key
|
||||
}
|
||||
|
||||
generateTokenErrorCausingStrategy := &dynamicOauth2HMACStrategy{
|
||||
fositeConfig: &compose.Config{},
|
||||
keyFunc: func() []byte { return []byte("too_short_causes_error") }, // secret key is below required 32 characters
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
generateFunc func(ctx context.Context, requester fosite.Requester) (token string, signature string, err error)
|
||||
errGenerateFunc func(ctx context.Context, requester fosite.Requester) (token string, signature string, err error)
|
||||
wantPrefix string
|
||||
}{
|
||||
{
|
||||
name: "access tokens are base64 random bytes followed by dot followed by base64 signature of the random bytes in the default HMAC strategy",
|
||||
generateFunc: s.GenerateAccessToken,
|
||||
errGenerateFunc: generateTokenErrorCausingStrategy.GenerateAccessToken,
|
||||
wantPrefix: "pin_at_",
|
||||
},
|
||||
{
|
||||
name: "refresh tokens are base64 random bytes followed by dot followed by base64 signature of the random bytes in the default HMAC strategy",
|
||||
generateFunc: s.GenerateRefreshToken,
|
||||
errGenerateFunc: generateTokenErrorCausingStrategy.GenerateRefreshToken,
|
||||
wantPrefix: "pin_rt_",
|
||||
},
|
||||
{
|
||||
name: "authcodes are base64 random bytes followed by dot followed by base64 signature of the random bytes in the default HMAC strategy",
|
||||
generateFunc: s.GenerateAuthorizeCode,
|
||||
errGenerateFunc: generateTokenErrorCausingStrategy.GenerateAuthorizeCode,
|
||||
wantPrefix: "pin_ac_",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
requireRandomTokenInExpectedFormat := func(token, signature string) {
|
||||
// Tokens should start with a custom prefix to make them identifiable as tokens when seen by a user
|
||||
// out of context, such as when accidentally committed to a GitHub repo.
|
||||
require.True(t, strings.HasPrefix(token, tt.wantPrefix), "token %q did not have expected prefix %q", token, tt.wantPrefix)
|
||||
require.Equal(t, 1, strings.Count(token, "."))
|
||||
require.Len(t, signature, 43)
|
||||
require.True(t, strings.HasSuffix(token, "."+signature), "token %q did not end with dot followed by signature", token)
|
||||
// The part before the dot is the prefix plus 43 characters of base64 encoded random bytes.
|
||||
require.Len(t, strings.Split(token, ".")[0], len(tt.wantPrefix)+43)
|
||||
}
|
||||
|
||||
var ctxIsIgnored context.Context
|
||||
var requesterIsIgnored fosite.Requester
|
||||
|
||||
generatedToken1, signature1, err := tt.generateFunc(ctxIsIgnored, requesterIsIgnored)
|
||||
require.NoError(t, err)
|
||||
requireRandomTokenInExpectedFormat(generatedToken1, signature1)
|
||||
|
||||
generatedToken2, signature2, err := tt.generateFunc(ctxIsIgnored, requesterIsIgnored)
|
||||
require.NoError(t, err)
|
||||
requireRandomTokenInExpectedFormat(generatedToken2, signature2)
|
||||
|
||||
// Each generated token is random/different.
|
||||
require.NotEqual(t, generatedToken1, generatedToken2)
|
||||
require.NotEqual(t, signature1, signature2)
|
||||
|
||||
// Test the return values when an error is encountered during generation.
|
||||
generatedToken3, signature3, err := tt.errGenerateFunc(ctxIsIgnored, requesterIsIgnored)
|
||||
require.EqualError(t, err, "secret for signing HMAC-SHA512/256 is expected to be 32 byte long, got 22 byte")
|
||||
require.Empty(t, generatedToken3)
|
||||
require.Empty(t, signature3)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDynamicOauth2HMACStrategy_Validate(t *testing.T) {
|
||||
s := &dynamicOauth2HMACStrategy{
|
||||
fositeConfig: &compose.Config{}, // defaults are good enough for this unit test
|
||||
keyFunc: func() []byte { return []byte("12345678901234567890123456789012") },
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
generateFunc func(ctx context.Context, requester fosite.Requester) (token string, signature string, err error)
|
||||
validateFunc func(ctx context.Context, requester fosite.Requester, token string) error
|
||||
wantPrefix string
|
||||
}{
|
||||
{
|
||||
name: "access tokens",
|
||||
generateFunc: s.GenerateAccessToken,
|
||||
validateFunc: s.ValidateAccessToken,
|
||||
wantPrefix: "pin_at_",
|
||||
},
|
||||
{
|
||||
name: "refresh tokens",
|
||||
generateFunc: s.GenerateRefreshToken,
|
||||
validateFunc: s.ValidateRefreshToken,
|
||||
wantPrefix: "pin_rt_",
|
||||
},
|
||||
{
|
||||
name: "authcodes",
|
||||
generateFunc: s.GenerateAuthorizeCode,
|
||||
validateFunc: s.ValidateAuthorizeCode,
|
||||
wantPrefix: "pin_ac_",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var ctxIsIgnored context.Context
|
||||
var requesterIsIgnored fosite.Requester
|
||||
|
||||
unexpiredSession := &fosite.DefaultSession{}
|
||||
unexpiredSession.SetExpiresAt(fosite.RefreshToken, time.Now().Add(time.Hour))
|
||||
unexpiredSession.SetExpiresAt(fosite.AccessToken, time.Now().Add(time.Hour))
|
||||
unexpiredSession.SetExpiresAt(fosite.AuthorizeCode, time.Now().Add(time.Hour))
|
||||
requesterWithUnexpiredTokens := &fosite.Request{Session: unexpiredSession}
|
||||
|
||||
expiredSession := &fosite.DefaultSession{}
|
||||
expiredSession.SetExpiresAt(fosite.RefreshToken, time.Now().Add(-time.Hour))
|
||||
expiredSession.SetExpiresAt(fosite.AccessToken, time.Now().Add(-time.Hour))
|
||||
expiredSession.SetExpiresAt(fosite.AuthorizeCode, time.Now().Add(-time.Hour))
|
||||
requesterWithExpiredTokens := &fosite.Request{Session: expiredSession}
|
||||
|
||||
generatedToken, _, err := tt.generateFunc(ctxIsIgnored, requesterIsIgnored)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tt.validateFunc(ctxIsIgnored, requesterWithUnexpiredTokens, generatedToken))
|
||||
|
||||
generatedToken, _, err = tt.generateFunc(ctxIsIgnored, requesterIsIgnored)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, tt.validateFunc(ctxIsIgnored, requesterWithUnexpiredTokens, generatedToken))
|
||||
|
||||
// Generated token has prefix.
|
||||
require.True(t, strings.HasPrefix(generatedToken, tt.wantPrefix), "token %q did not have expected prefix %q", generatedToken, tt.wantPrefix)
|
||||
|
||||
// Validate when expired according to session.
|
||||
require.EqualError(t, tt.validateFunc(ctxIsIgnored, requesterWithExpiredTokens, generatedToken), "invalid_token")
|
||||
|
||||
// Validate when missing prefix.
|
||||
require.EqualError(t, tt.validateFunc(ctxIsIgnored, requesterWithUnexpiredTokens, strings.TrimPrefix(generatedToken, tt.wantPrefix)), "invalid_token")
|
||||
|
||||
// Validate when wrong prefix.
|
||||
require.EqualError(t, tt.validateFunc(ctxIsIgnored, requesterWithUnexpiredTokens, "pin_wrong_"+strings.TrimPrefix(generatedToken, tt.wantPrefix)), "invalid_token")
|
||||
|
||||
// Validate when correct prefix but otherwise invalid format.
|
||||
require.EqualError(t, tt.validateFunc(ctxIsIgnored, requesterWithUnexpiredTokens, tt.wantPrefix+"illegal token"), "invalid_token")
|
||||
|
||||
// Validate when correct prefix but bad signature.
|
||||
var b64 = base64.URLEncoding.WithPadding(base64.NoPadding)
|
||||
tokenWithBadSig := tt.wantPrefix + b64.EncodeToString([]byte("some-token")) + "." + b64.EncodeToString([]byte("bad-signature"))
|
||||
require.EqualError(t, tt.validateFunc(ctxIsIgnored, requesterWithUnexpiredTokens, tokenWithBadSig), "token_signature_mismatch")
|
||||
})
|
||||
}
|
||||
}
|
@ -3251,6 +3251,9 @@ func requireValidRefreshTokenStorage(
|
||||
storedRequest, err := storage.GetRefreshTokenSession(context.Background(), getFositeDataSignature(t, refreshTokenString), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Refresh tokens should start with the custom prefix "pin_rt_" to make them identifiable as refresh tokens when seen by a user out of context.
|
||||
require.True(t, strings.HasPrefix(refreshTokenString, "pin_rt_"), "token %q did not have expected prefix 'pin_rt_'", refreshTokenString)
|
||||
|
||||
// Fosite stores refresh tokens without any of the original request form parameters.
|
||||
requireValidStoredRequest(
|
||||
t,
|
||||
@ -3287,6 +3290,9 @@ func requireValidAccessTokenStorage(
|
||||
storedRequest, err := storage.GetAccessTokenSession(context.Background(), getFositeDataSignature(t, accessTokenString), nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Access tokens should start with the custom prefix "pin_at_" to make them identifiable as access tokens when seen by a user out of context.
|
||||
require.True(t, strings.HasPrefix(accessTokenString, "pin_at_"), "token %q did not have expected prefix 'pin_at_'", accessTokenString)
|
||||
|
||||
// Make sure the other body fields are valid.
|
||||
tokenType, ok := body["token_type"]
|
||||
require.True(t, ok)
|
||||
|
@ -901,6 +901,9 @@ func RequireAuthCodeRegexpMatch(
|
||||
require.Lenf(t, submatches, 2, "no regexp match in actualContent: %", actualContent)
|
||||
capturedAuthCode := submatches[1]
|
||||
|
||||
// Authcodes should start with the custom prefix "pin_ac_" to make them identifiable as authcodes when seen by a user out of context.
|
||||
require.True(t, strings.HasPrefix(capturedAuthCode, "pin_ac_"), "token %q did not have expected prefix 'pin_ac_'", capturedAuthCode)
|
||||
|
||||
// fosite authcodes are in the format `data.signature`, so grab the signature part, which is the lookup key in the storage interface
|
||||
authcodeDataAndSignature := strings.Split(capturedAuthCode, ".")
|
||||
require.Len(t, authcodeDataAndSignature, 2)
|
||||
|
@ -1818,6 +1818,9 @@ func testSupervisorLogin(
|
||||
authcode := callback.URL.Query().Get("code")
|
||||
require.NotEmpty(t, authcode)
|
||||
|
||||
// Authcodes should start with the custom prefix "pin_ac_" to make them identifiable as authcodes when seen by a user out of context.
|
||||
require.True(t, strings.HasPrefix(authcode, "pin_ac_"), "token %q did not have expected prefix 'pin_ac_'", authcode)
|
||||
|
||||
// Call the token endpoint to get tokens.
|
||||
tokenResponse, err := downstreamOAuth2Config.Exchange(oidcHTTPClientContext, authcode, pkceParam.Verifier())
|
||||
require.NoError(t, err)
|
||||
@ -1973,8 +1976,12 @@ func verifyTokenResponse(
|
||||
require.NotZero(t, tokenResponse.Expiry)
|
||||
expectedAccessTokenLifetime := oidc.DefaultOIDCTimeoutsConfiguration().AccessTokenLifespan
|
||||
testutil.RequireTimeInDelta(t, time.Now().UTC().Add(expectedAccessTokenLifetime), tokenResponse.Expiry, time.Second*30)
|
||||
// Access tokens should start with the custom prefix "pin_at_" to make them identifiable as access tokens when seen by a user out of context.
|
||||
require.True(t, strings.HasPrefix(tokenResponse.AccessToken, "pin_at_"), "token %q did not have expected prefix 'pin_at_'", tokenResponse.AccessToken)
|
||||
|
||||
require.NotEmpty(t, tokenResponse.RefreshToken)
|
||||
// Refresh tokens should start with the custom prefix "pin_rt_" to make them identifiable as refresh tokens when seen by a user out of context.
|
||||
require.True(t, strings.HasPrefix(tokenResponse.RefreshToken, "pin_rt_"), "token %q did not have expected prefix 'pin_rt_'", tokenResponse.RefreshToken)
|
||||
}
|
||||
|
||||
func requestAuthorizationUsingBrowserAuthcodeFlow(t *testing.T, downstreamAuthorizeURL, downstreamCallbackURL, _, _ string, httpClient *http.Client) {
|
||||
|
Loading…
Reference in New Issue
Block a user