Add custom prefix to downstream access and refresh tokens and authcodes

This commit is contained in:
Ryan Richard 2022-04-13 10:13:27 -07:00
parent 13daf59217
commit 53348b8464
5 changed files with 276 additions and 7 deletions

View File

@ -1,14 +1,22 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2022 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package oidc package oidc
import ( import (
"context" "context"
"strings"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/fosite/compose" "github.com/ory/fosite/compose"
"github.com/ory/fosite/handler/oauth2" "github.com/ory/fosite/handler/oauth2"
errorsx "github.com/pkg/errors"
)
const (
accessTokenPrefix = "pin_at_" // "Pinniped access token" abbreviated.
refreshTokenPrefix = "pin_rt_" // "Pinniped refresh token" abbreviated.
authcodePrefix = "pin_ac_" // "Pinniped authorization code" abbreviated.
) )
// dynamicOauth2HMACStrategy is an oauth2.CoreStrategy that can dynamically load an HMAC key to sign // dynamicOauth2HMACStrategy is an oauth2.CoreStrategy that can dynamically load an HMAC key to sign
@ -19,6 +27,9 @@ import (
// If we ever update FederationDomain's to hold their signing key, we might not need this type, since we // If we ever update FederationDomain's to hold their signing key, we might not need this type, since we
// could have an invariant that routes to an FederationDomain's endpoints are only wired up if an // could have an invariant that routes to an FederationDomain's endpoints are only wired up if an
// FederationDomain has a valid signing key. // FederationDomain has a valid signing key.
//
// Tokens start with a custom prefix to make them identifiable as tokens when seen by a user
// out of context, such as when accidentally committed to a GitHub repo.
type dynamicOauth2HMACStrategy struct { type dynamicOauth2HMACStrategy struct {
fositeConfig *compose.Config fositeConfig *compose.Config
keyFunc func() []byte keyFunc func() []byte
@ -44,7 +55,11 @@ func (s *dynamicOauth2HMACStrategy) GenerateAccessToken(
ctx context.Context, ctx context.Context,
requester fosite.Requester, requester fosite.Requester,
) (token string, signature string, err error) { ) (token string, signature string, err error) {
return s.delegate().GenerateAccessToken(ctx, requester) token, sig, err := s.delegate().GenerateAccessToken(ctx, requester)
if err == nil {
token = accessTokenPrefix + token
}
return token, sig, err
} }
func (s *dynamicOauth2HMACStrategy) ValidateAccessToken( func (s *dynamicOauth2HMACStrategy) ValidateAccessToken(
@ -52,7 +67,11 @@ func (s *dynamicOauth2HMACStrategy) ValidateAccessToken(
requester fosite.Requester, requester fosite.Requester,
token string, token string,
) (err error) { ) (err error) {
return s.delegate().ValidateAccessToken(ctx, requester, token) if !strings.HasPrefix(token, accessTokenPrefix) {
return errorsx.WithStack(fosite.ErrInvalidTokenFormat.
WithDebugf("Access token did not have prefix %q", accessTokenPrefix))
}
return s.delegate().ValidateAccessToken(ctx, requester, token[len(accessTokenPrefix):])
} }
func (s *dynamicOauth2HMACStrategy) RefreshTokenSignature(token string) string { func (s *dynamicOauth2HMACStrategy) RefreshTokenSignature(token string) string {
@ -63,7 +82,11 @@ func (s *dynamicOauth2HMACStrategy) GenerateRefreshToken(
ctx context.Context, ctx context.Context,
requester fosite.Requester, requester fosite.Requester,
) (token string, signature string, err error) { ) (token string, signature string, err error) {
return s.delegate().GenerateRefreshToken(ctx, requester) token, sig, err := s.delegate().GenerateRefreshToken(ctx, requester)
if err == nil {
token = refreshTokenPrefix + token
}
return token, sig, err
} }
func (s *dynamicOauth2HMACStrategy) ValidateRefreshToken( func (s *dynamicOauth2HMACStrategy) ValidateRefreshToken(
@ -71,7 +94,11 @@ func (s *dynamicOauth2HMACStrategy) ValidateRefreshToken(
requester fosite.Requester, requester fosite.Requester,
token string, token string,
) (err error) { ) (err error) {
return s.delegate().ValidateRefreshToken(ctx, requester, token) if !strings.HasPrefix(token, refreshTokenPrefix) {
return errorsx.WithStack(fosite.ErrInvalidTokenFormat.
WithDebugf("Refresh token did not have prefix %q", refreshTokenPrefix))
}
return s.delegate().ValidateRefreshToken(ctx, requester, token[len(refreshTokenPrefix):])
} }
func (s *dynamicOauth2HMACStrategy) AuthorizeCodeSignature(token string) string { func (s *dynamicOauth2HMACStrategy) AuthorizeCodeSignature(token string) string {
@ -82,7 +109,11 @@ func (s *dynamicOauth2HMACStrategy) GenerateAuthorizeCode(
ctx context.Context, ctx context.Context,
requester fosite.Requester, requester fosite.Requester,
) (token string, signature string, err error) { ) (token string, signature string, err error) {
return s.delegate().GenerateAuthorizeCode(ctx, requester) authcode, sig, err := s.delegate().GenerateAuthorizeCode(ctx, requester)
if err == nil {
authcode = authcodePrefix + authcode
}
return authcode, sig, err
} }
func (s *dynamicOauth2HMACStrategy) ValidateAuthorizeCode( func (s *dynamicOauth2HMACStrategy) ValidateAuthorizeCode(
@ -90,7 +121,11 @@ func (s *dynamicOauth2HMACStrategy) ValidateAuthorizeCode(
requester fosite.Requester, requester fosite.Requester,
token string, token string,
) (err error) { ) (err error) {
return s.delegate().ValidateAuthorizeCode(ctx, requester, token) if !strings.HasPrefix(token, authcodePrefix) {
return errorsx.WithStack(fosite.ErrInvalidTokenFormat.
WithDebugf("Authorization code did not have prefix %q", authcodePrefix))
}
return s.delegate().ValidateAuthorizeCode(ctx, requester, token[len(authcodePrefix):])
} }
func (s *dynamicOauth2HMACStrategy) delegate() *oauth2.HMACSHAStrategy { func (s *dynamicOauth2HMACStrategy) delegate() *oauth2.HMACSHAStrategy {

View File

@ -0,0 +1,218 @@
// Copyright 2022 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package oidc
import (
"context"
"encoding/base64"
"strings"
"testing"
"time"
"github.com/ory/fosite"
"github.com/ory/fosite/compose"
"github.com/stretchr/testify/require"
)
func TestDynamicOauth2HMACStrategy_Signatures(t *testing.T) {
s := &dynamicOauth2HMACStrategy{
fositeConfig: &compose.Config{}, // defaults are good enough for this unit test
keyFunc: func() []byte { return []byte("12345678901234567890123456789012") },
}
tests := []struct {
name string
token string
signatureFunc func(token string) (signature string)
wantSignature string
}{
{
name: "access token signature is the part after the dot in the default HMAC strategy",
token: "token.signature",
signatureFunc: s.AccessTokenSignature,
wantSignature: "signature",
},
{
name: "refresh token signature is the part after the dot in the default HMAC strategy",
token: "token.signature",
signatureFunc: s.RefreshTokenSignature,
wantSignature: "signature",
},
{
name: "authcode signature is the part after the dot in the default HMAC strategy",
token: "token.signature",
signatureFunc: s.AuthorizeCodeSignature,
wantSignature: "signature",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.wantSignature, tt.signatureFunc(tt.token))
})
}
}
func TestDynamicOauth2HMACStrategy_Generate(t *testing.T) {
s := &dynamicOauth2HMACStrategy{
fositeConfig: &compose.Config{}, // defaults are good enough for this unit test
keyFunc: func() []byte { return []byte("12345678901234567890123456789012") }, // 32 character secret key
}
generateTokenErrorCausingStrategy := &dynamicOauth2HMACStrategy{
fositeConfig: &compose.Config{},
keyFunc: func() []byte { return []byte("too_short_causes_error") }, // secret key is below required 32 characters
}
tests := []struct {
name string
generateFunc func(ctx context.Context, requester fosite.Requester) (token string, signature string, err error)
errGenerateFunc func(ctx context.Context, requester fosite.Requester) (token string, signature string, err error)
wantPrefix string
}{
{
name: "access tokens are base64 random bytes followed by dot followed by base64 signature of the random bytes in the default HMAC strategy",
generateFunc: s.GenerateAccessToken,
errGenerateFunc: generateTokenErrorCausingStrategy.GenerateAccessToken,
wantPrefix: "pin_at_",
},
{
name: "refresh tokens are base64 random bytes followed by dot followed by base64 signature of the random bytes in the default HMAC strategy",
generateFunc: s.GenerateRefreshToken,
errGenerateFunc: generateTokenErrorCausingStrategy.GenerateRefreshToken,
wantPrefix: "pin_rt_",
},
{
name: "authcodes are base64 random bytes followed by dot followed by base64 signature of the random bytes in the default HMAC strategy",
generateFunc: s.GenerateAuthorizeCode,
errGenerateFunc: generateTokenErrorCausingStrategy.GenerateAuthorizeCode,
wantPrefix: "pin_ac_",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
requireRandomTokenInExpectedFormat := func(token, signature string) {
// Tokens should start with a custom prefix to make them identifiable as tokens when seen by a user
// out of context, such as when accidentally committed to a GitHub repo.
require.True(t, strings.HasPrefix(token, tt.wantPrefix), "token %q did not have expected prefix %q", token, tt.wantPrefix)
require.Equal(t, 1, strings.Count(token, "."))
require.Len(t, signature, 43)
require.True(t, strings.HasSuffix(token, "."+signature), "token %q did not end with dot followed by signature", token)
// The part before the dot is the prefix plus 43 characters of base64 encoded random bytes.
require.Len(t, strings.Split(token, ".")[0], len(tt.wantPrefix)+43)
}
var ctxIsIgnored context.Context
var requesterIsIgnored fosite.Requester
generatedToken1, signature1, err := tt.generateFunc(ctxIsIgnored, requesterIsIgnored)
require.NoError(t, err)
requireRandomTokenInExpectedFormat(generatedToken1, signature1)
generatedToken2, signature2, err := tt.generateFunc(ctxIsIgnored, requesterIsIgnored)
require.NoError(t, err)
requireRandomTokenInExpectedFormat(generatedToken2, signature2)
// Each generated token is random/different.
require.NotEqual(t, generatedToken1, generatedToken2)
require.NotEqual(t, signature1, signature2)
// Test the return values when an error is encountered during generation.
generatedToken3, signature3, err := tt.errGenerateFunc(ctxIsIgnored, requesterIsIgnored)
require.EqualError(t, err, "secret for signing HMAC-SHA512/256 is expected to be 32 byte long, got 22 byte")
require.Empty(t, generatedToken3)
require.Empty(t, signature3)
})
}
}
func TestDynamicOauth2HMACStrategy_Validate(t *testing.T) {
s := &dynamicOauth2HMACStrategy{
fositeConfig: &compose.Config{}, // defaults are good enough for this unit test
keyFunc: func() []byte { return []byte("12345678901234567890123456789012") },
}
tests := []struct {
name string
generateFunc func(ctx context.Context, requester fosite.Requester) (token string, signature string, err error)
validateFunc func(ctx context.Context, requester fosite.Requester, token string) error
wantPrefix string
}{
{
name: "access tokens",
generateFunc: s.GenerateAccessToken,
validateFunc: s.ValidateAccessToken,
wantPrefix: "pin_at_",
},
{
name: "refresh tokens",
generateFunc: s.GenerateRefreshToken,
validateFunc: s.ValidateRefreshToken,
wantPrefix: "pin_rt_",
},
{
name: "authcodes",
generateFunc: s.GenerateAuthorizeCode,
validateFunc: s.ValidateAuthorizeCode,
wantPrefix: "pin_ac_",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
var ctxIsIgnored context.Context
var requesterIsIgnored fosite.Requester
unexpiredSession := &fosite.DefaultSession{}
unexpiredSession.SetExpiresAt(fosite.RefreshToken, time.Now().Add(time.Hour))
unexpiredSession.SetExpiresAt(fosite.AccessToken, time.Now().Add(time.Hour))
unexpiredSession.SetExpiresAt(fosite.AuthorizeCode, time.Now().Add(time.Hour))
requesterWithUnexpiredTokens := &fosite.Request{Session: unexpiredSession}
expiredSession := &fosite.DefaultSession{}
expiredSession.SetExpiresAt(fosite.RefreshToken, time.Now().Add(-time.Hour))
expiredSession.SetExpiresAt(fosite.AccessToken, time.Now().Add(-time.Hour))
expiredSession.SetExpiresAt(fosite.AuthorizeCode, time.Now().Add(-time.Hour))
requesterWithExpiredTokens := &fosite.Request{Session: expiredSession}
generatedToken, _, err := tt.generateFunc(ctxIsIgnored, requesterIsIgnored)
require.NoError(t, err)
require.NoError(t, tt.validateFunc(ctxIsIgnored, requesterWithUnexpiredTokens, generatedToken))
generatedToken, _, err = tt.generateFunc(ctxIsIgnored, requesterIsIgnored)
require.NoError(t, err)
require.NoError(t, tt.validateFunc(ctxIsIgnored, requesterWithUnexpiredTokens, generatedToken))
// Generated token has prefix.
require.True(t, strings.HasPrefix(generatedToken, tt.wantPrefix), "token %q did not have expected prefix %q", generatedToken, tt.wantPrefix)
// Validate when expired according to session.
require.EqualError(t, tt.validateFunc(ctxIsIgnored, requesterWithExpiredTokens, generatedToken), "invalid_token")
// Validate when missing prefix.
require.EqualError(t, tt.validateFunc(ctxIsIgnored, requesterWithUnexpiredTokens, strings.TrimPrefix(generatedToken, tt.wantPrefix)), "invalid_token")
// Validate when wrong prefix.
require.EqualError(t, tt.validateFunc(ctxIsIgnored, requesterWithUnexpiredTokens, "pin_wrong_"+strings.TrimPrefix(generatedToken, tt.wantPrefix)), "invalid_token")
// Validate when correct prefix but otherwise invalid format.
require.EqualError(t, tt.validateFunc(ctxIsIgnored, requesterWithUnexpiredTokens, tt.wantPrefix+"illegal token"), "invalid_token")
// Validate when correct prefix but bad signature.
var b64 = base64.URLEncoding.WithPadding(base64.NoPadding)
tokenWithBadSig := tt.wantPrefix + b64.EncodeToString([]byte("some-token")) + "." + b64.EncodeToString([]byte("bad-signature"))
require.EqualError(t, tt.validateFunc(ctxIsIgnored, requesterWithUnexpiredTokens, tokenWithBadSig), "token_signature_mismatch")
})
}
}

View File

@ -3251,6 +3251,9 @@ func requireValidRefreshTokenStorage(
storedRequest, err := storage.GetRefreshTokenSession(context.Background(), getFositeDataSignature(t, refreshTokenString), nil) storedRequest, err := storage.GetRefreshTokenSession(context.Background(), getFositeDataSignature(t, refreshTokenString), nil)
require.NoError(t, err) require.NoError(t, err)
// Refresh tokens should start with the custom prefix "pin_rt_" to make them identifiable as refresh tokens when seen by a user out of context.
require.True(t, strings.HasPrefix(refreshTokenString, "pin_rt_"), "token %q did not have expected prefix 'pin_rt_'", refreshTokenString)
// Fosite stores refresh tokens without any of the original request form parameters. // Fosite stores refresh tokens without any of the original request form parameters.
requireValidStoredRequest( requireValidStoredRequest(
t, t,
@ -3287,6 +3290,9 @@ func requireValidAccessTokenStorage(
storedRequest, err := storage.GetAccessTokenSession(context.Background(), getFositeDataSignature(t, accessTokenString), nil) storedRequest, err := storage.GetAccessTokenSession(context.Background(), getFositeDataSignature(t, accessTokenString), nil)
require.NoError(t, err) require.NoError(t, err)
// Access tokens should start with the custom prefix "pin_at_" to make them identifiable as access tokens when seen by a user out of context.
require.True(t, strings.HasPrefix(accessTokenString, "pin_at_"), "token %q did not have expected prefix 'pin_at_'", accessTokenString)
// Make sure the other body fields are valid. // Make sure the other body fields are valid.
tokenType, ok := body["token_type"] tokenType, ok := body["token_type"]
require.True(t, ok) require.True(t, ok)

View File

@ -901,6 +901,9 @@ func RequireAuthCodeRegexpMatch(
require.Lenf(t, submatches, 2, "no regexp match in actualContent: %", actualContent) require.Lenf(t, submatches, 2, "no regexp match in actualContent: %", actualContent)
capturedAuthCode := submatches[1] capturedAuthCode := submatches[1]
// Authcodes should start with the custom prefix "pin_ac_" to make them identifiable as authcodes when seen by a user out of context.
require.True(t, strings.HasPrefix(capturedAuthCode, "pin_ac_"), "token %q did not have expected prefix 'pin_ac_'", capturedAuthCode)
// fosite authcodes are in the format `data.signature`, so grab the signature part, which is the lookup key in the storage interface // fosite authcodes are in the format `data.signature`, so grab the signature part, which is the lookup key in the storage interface
authcodeDataAndSignature := strings.Split(capturedAuthCode, ".") authcodeDataAndSignature := strings.Split(capturedAuthCode, ".")
require.Len(t, authcodeDataAndSignature, 2) require.Len(t, authcodeDataAndSignature, 2)

View File

@ -1818,6 +1818,9 @@ func testSupervisorLogin(
authcode := callback.URL.Query().Get("code") authcode := callback.URL.Query().Get("code")
require.NotEmpty(t, authcode) require.NotEmpty(t, authcode)
// Authcodes should start with the custom prefix "pin_ac_" to make them identifiable as authcodes when seen by a user out of context.
require.True(t, strings.HasPrefix(authcode, "pin_ac_"), "token %q did not have expected prefix 'pin_ac_'", authcode)
// Call the token endpoint to get tokens. // Call the token endpoint to get tokens.
tokenResponse, err := downstreamOAuth2Config.Exchange(oidcHTTPClientContext, authcode, pkceParam.Verifier()) tokenResponse, err := downstreamOAuth2Config.Exchange(oidcHTTPClientContext, authcode, pkceParam.Verifier())
require.NoError(t, err) require.NoError(t, err)
@ -1973,8 +1976,12 @@ func verifyTokenResponse(
require.NotZero(t, tokenResponse.Expiry) require.NotZero(t, tokenResponse.Expiry)
expectedAccessTokenLifetime := oidc.DefaultOIDCTimeoutsConfiguration().AccessTokenLifespan expectedAccessTokenLifetime := oidc.DefaultOIDCTimeoutsConfiguration().AccessTokenLifespan
testutil.RequireTimeInDelta(t, time.Now().UTC().Add(expectedAccessTokenLifetime), tokenResponse.Expiry, time.Second*30) testutil.RequireTimeInDelta(t, time.Now().UTC().Add(expectedAccessTokenLifetime), tokenResponse.Expiry, time.Second*30)
// Access tokens should start with the custom prefix "pin_at_" to make them identifiable as access tokens when seen by a user out of context.
require.True(t, strings.HasPrefix(tokenResponse.AccessToken, "pin_at_"), "token %q did not have expected prefix 'pin_at_'", tokenResponse.AccessToken)
require.NotEmpty(t, tokenResponse.RefreshToken) require.NotEmpty(t, tokenResponse.RefreshToken)
// Refresh tokens should start with the custom prefix "pin_rt_" to make them identifiable as refresh tokens when seen by a user out of context.
require.True(t, strings.HasPrefix(tokenResponse.RefreshToken, "pin_rt_"), "token %q did not have expected prefix 'pin_rt_'", tokenResponse.RefreshToken)
} }
func requestAuthorizationUsingBrowserAuthcodeFlow(t *testing.T, downstreamAuthorizeURL, downstreamCallbackURL, _, _ string, httpClient *http.Client) { func requestAuthorizationUsingBrowserAuthcodeFlow(t *testing.T, downstreamAuthorizeURL, downstreamCallbackURL, _, _ string, httpClient *http.Client) {