Refactor oidcclient.Login to use new upstreamoidc package.
Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
parent
4b60c922ef
commit
b272b3f331
@ -20,6 +20,10 @@ import (
|
|||||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func New(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
||||||
|
return &ProviderConfig{Config: config, Provider: provider}
|
||||||
|
}
|
||||||
|
|
||||||
// ProviderConfig holds the active configuration of an upstream OIDC provider.
|
// ProviderConfig holds the active configuration of an upstream OIDC provider.
|
||||||
type ProviderConfig struct {
|
type ProviderConfig struct {
|
||||||
Name string
|
Name string
|
||||||
@ -31,9 +35,6 @@ type ProviderConfig struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// *ProviderConfig should implement provider.UpstreamOIDCIdentityProviderI.
|
|
||||||
var _ provider.UpstreamOIDCIdentityProviderI = (*ProviderConfig)(nil)
|
|
||||||
|
|
||||||
func (p *ProviderConfig) GetName() string {
|
func (p *ProviderConfig) GetName() string {
|
||||||
return p.Name
|
return p.Name
|
||||||
}
|
}
|
||||||
|
@ -16,10 +16,11 @@ import (
|
|||||||
"github.com/coreos/go-oidc"
|
"github.com/coreos/go-oidc"
|
||||||
"github.com/pkg/browser"
|
"github.com/pkg/browser"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
|
||||||
|
|
||||||
"go.pinniped.dev/internal/httputil/httperr"
|
"go.pinniped.dev/internal/httputil/httperr"
|
||||||
"go.pinniped.dev/internal/httputil/securityheader"
|
"go.pinniped.dev/internal/httputil/securityheader"
|
||||||
|
"go.pinniped.dev/internal/oidc/provider"
|
||||||
|
"go.pinniped.dev/internal/upstreamoidc"
|
||||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||||
@ -52,18 +53,18 @@ type handlerState struct {
|
|||||||
callbackPath string
|
callbackPath string
|
||||||
|
|
||||||
// Generated parameters of a login flow.
|
// Generated parameters of a login flow.
|
||||||
idTokenVerifier *oidc.IDTokenVerifier
|
provider *oidc.Provider
|
||||||
oauth2Config *oauth2.Config
|
oauth2Config *oauth2.Config
|
||||||
state state.State
|
state state.State
|
||||||
nonce nonce.Nonce
|
nonce nonce.Nonce
|
||||||
pkce pkce.Code
|
pkce pkce.Code
|
||||||
|
|
||||||
// External calls for things.
|
// External calls for things.
|
||||||
generateState func() (state.State, error)
|
generateState func() (state.State, error)
|
||||||
generatePKCE func() (pkce.Code, error)
|
generatePKCE func() (pkce.Code, error)
|
||||||
generateNonce func() (nonce.Nonce, error)
|
generateNonce func() (nonce.Nonce, error)
|
||||||
openURL func(string) error
|
openURL func(string) error
|
||||||
oidcDiscover func(context.Context, string) (discoveryI, error)
|
getProvider func(*oauth2.Config, *oidc.Provider) provider.UpstreamOIDCIdentityProviderI
|
||||||
|
|
||||||
callbacks chan callbackResult
|
callbacks chan callbackResult
|
||||||
}
|
}
|
||||||
@ -152,11 +153,6 @@ type nopCache struct{}
|
|||||||
func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil }
|
func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil }
|
||||||
func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {}
|
func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {}
|
||||||
|
|
||||||
type discoveryI interface {
|
|
||||||
Endpoint() oauth2.Endpoint
|
|
||||||
Verifier(*oidc.Config) *oidc.IDTokenVerifier
|
|
||||||
}
|
|
||||||
|
|
||||||
// Login performs an OAuth2/OIDC authorization code login using a localhost listener.
|
// Login performs an OAuth2/OIDC authorization code login using a localhost listener.
|
||||||
func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, error) {
|
func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, error) {
|
||||||
h := handlerState{
|
h := handlerState{
|
||||||
@ -175,9 +171,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
|
|||||||
generateNonce: nonce.Generate,
|
generateNonce: nonce.Generate,
|
||||||
generatePKCE: pkce.Generate,
|
generatePKCE: pkce.Generate,
|
||||||
openURL: browser.OpenURL,
|
openURL: browser.OpenURL,
|
||||||
oidcDiscover: func(ctx context.Context, iss string) (discoveryI, error) {
|
getProvider: upstreamoidc.New,
|
||||||
return oidc.NewProvider(ctx, iss)
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
if err := opt(&h); err != nil {
|
if err := opt(&h); err != nil {
|
||||||
@ -222,16 +216,15 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Perform OIDC discovery.
|
// Perform OIDC discovery.
|
||||||
discovered, err := h.oidcDiscover(h.ctx, h.issuer)
|
h.provider, err = oidc.NewProvider(h.ctx, h.issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
|
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
|
||||||
}
|
}
|
||||||
h.idTokenVerifier = discovered.Verifier(&oidc.Config{ClientID: h.clientID})
|
|
||||||
|
|
||||||
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
|
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
|
||||||
h.oauth2Config = &oauth2.Config{
|
h.oauth2Config = &oauth2.Config{
|
||||||
ClientID: h.clientID,
|
ClientID: h.clientID,
|
||||||
Endpoint: discovered.Endpoint(),
|
Endpoint: h.provider.Endpoint(),
|
||||||
Scopes: h.scopes,
|
Scopes: h.scopes,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -301,7 +294,11 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype
|
|||||||
|
|
||||||
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
|
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
|
||||||
// some providers do not include one, so we skip the nonce validation here (but not other validations).
|
// some providers do not include one, so we skip the nonce validation here (but not other validations).
|
||||||
return h.validateToken(ctx, refreshed, false)
|
token, _, err := h.getProvider(h.oauth2Config, h.provider).ValidateToken(ctx, refreshed, "")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
|
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
|
||||||
@ -328,58 +325,18 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
|
|||||||
return httperr.Newf(http.StatusBadRequest, "login failed with code %q", errorParam)
|
return httperr.Newf(http.StatusBadRequest, "login failed with code %q", errorParam)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exchange the authorization code for access, ID, and refresh tokens.
|
// Exchange the authorization code for access, ID, and refresh tokens and perform required
|
||||||
oauth2Tok, err := h.oauth2Config.Exchange(r.Context(), params.Get("code"), h.pkce.Verifier())
|
// validations on the returned ID token.
|
||||||
|
token, _, err := h.getProvider(h.oauth2Config, h.provider).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
|
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform required validations on the returned ID token.
|
h.callbacks <- callbackResult{token: &token}
|
||||||
token, err := h.validateToken(r.Context(), oauth2Tok, true)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
h.callbacks <- callbackResult{token: token}
|
|
||||||
_, _ = w.Write([]byte("you have been logged in and may now close this tab"))
|
_, _ = w.Write([]byte("you have been logged in and may now close this tab"))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*oidctypes.Token, error) {
|
|
||||||
idTok, hasIDTok := tok.Extra("id_token").(string)
|
|
||||||
if !hasIDTok {
|
|
||||||
return nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
|
|
||||||
}
|
|
||||||
validated, err := h.idTokenVerifier.Verify(ctx, idTok)
|
|
||||||
if err != nil {
|
|
||||||
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
|
||||||
}
|
|
||||||
if validated.AccessTokenHash != "" {
|
|
||||||
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
|
|
||||||
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if checkNonce {
|
|
||||||
if err := h.nonce.Validate(validated); err != nil {
|
|
||||||
return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &oidctypes.Token{
|
|
||||||
AccessToken: &oidctypes.AccessToken{
|
|
||||||
Token: tok.AccessToken,
|
|
||||||
Type: tok.TokenType,
|
|
||||||
Expiry: metav1.NewTime(tok.Expiry),
|
|
||||||
},
|
|
||||||
RefreshToken: &oidctypes.RefreshToken{
|
|
||||||
Token: tok.RefreshToken,
|
|
||||||
},
|
|
||||||
IDToken: &oidctypes.IDToken{
|
|
||||||
Token: idTok,
|
|
||||||
Expiry: metav1.NewTime(validated.Expiry),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *handlerState) serve(listener net.Listener) func() {
|
func (h *handlerState) serve(listener net.Listener) func() {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))
|
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))
|
||||||
|
@ -18,11 +18,11 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gopkg.in/square/go-jose.v2"
|
|
||||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
|
|
||||||
"go.pinniped.dev/internal/httputil/httperr"
|
"go.pinniped.dev/internal/httputil/httperr"
|
||||||
"go.pinniped.dev/internal/mocks/mockkeyset"
|
"go.pinniped.dev/internal/mocks/mockupstreamoidcidentityprovider"
|
||||||
|
"go.pinniped.dev/internal/oidc/provider"
|
||||||
"go.pinniped.dev/internal/testutil"
|
"go.pinniped.dev/internal/testutil"
|
||||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||||
@ -57,19 +57,9 @@ func TestLogin(t *testing.T) {
|
|||||||
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
|
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
|
||||||
|
|
||||||
testToken := oidctypes.Token{
|
testToken := oidctypes.Token{
|
||||||
AccessToken: &oidctypes.AccessToken{
|
AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))},
|
||||||
Token: "test-access-token",
|
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
|
||||||
Expiry: metav1.NewTime(time1.Add(1 * time.Minute)),
|
IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))},
|
||||||
},
|
|
||||||
RefreshToken: &oidctypes.RefreshToken{
|
|
||||||
Token: "test-refresh-token",
|
|
||||||
},
|
|
||||||
IDToken: &oidctypes.IDToken{
|
|
||||||
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above):
|
|
||||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775
|
|
||||||
Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA",
|
|
||||||
Expiry: metav1.NewTime(time1.Add(2 * time.Minute)),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a test server that returns 500 errors
|
// Start a test server that returns 500 errors
|
||||||
@ -78,7 +68,7 @@ func TestLogin(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
t.Cleanup(errorServer.Close)
|
t.Cleanup(errorServer.Close)
|
||||||
|
|
||||||
// Start a test server that returns a real keyset and answers refresh requests.
|
// Start a test server that returns a real discovery document and answers refresh requests.
|
||||||
providerMux := http.NewServeMux()
|
providerMux := http.NewServeMux()
|
||||||
successServer := httptest.NewServer(providerMux)
|
successServer := httptest.NewServer(providerMux)
|
||||||
t.Cleanup(successServer.Close)
|
t.Cleanup(successServer.Close)
|
||||||
@ -248,6 +238,14 @@ func TestLogin(t *testing.T) {
|
|||||||
clientID: "test-client-id",
|
clientID: "test-client-id",
|
||||||
opt: func(t *testing.T) Option {
|
opt: func(t *testing.T) Option {
|
||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
|
h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
||||||
|
mock := mockUpstream(t)
|
||||||
|
mock.EXPECT().
|
||||||
|
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
|
||||||
|
Return(testToken, nil, nil)
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
|
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
|
||||||
IDToken: &oidctypes.IDToken{
|
IDToken: &oidctypes.IDToken{
|
||||||
Token: "expired-test-id-token",
|
Token: "expired-test-id-token",
|
||||||
@ -268,12 +266,6 @@ func TestLogin(t *testing.T) {
|
|||||||
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
|
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
|
||||||
})
|
})
|
||||||
h.cache = cache
|
h.cache = cache
|
||||||
|
|
||||||
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
|
||||||
provider, err := oidc.NewProvider(ctx, iss)
|
|
||||||
require.NoError(t, err)
|
|
||||||
return &mockDiscovery{provider: provider}, nil
|
|
||||||
}
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -285,6 +277,14 @@ func TestLogin(t *testing.T) {
|
|||||||
clientID: "test-client-id",
|
clientID: "test-client-id",
|
||||||
opt: func(t *testing.T) Option {
|
opt: func(t *testing.T) Option {
|
||||||
return func(h *handlerState) error {
|
return func(h *handlerState) error {
|
||||||
|
h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
||||||
|
mock := mockUpstream(t)
|
||||||
|
mock.EXPECT().
|
||||||
|
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
|
||||||
|
Return(oidctypes.Token{}, nil, fmt.Errorf("some validation error"))
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
|
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
|
||||||
IDToken: &oidctypes.IDToken{
|
IDToken: &oidctypes.IDToken{
|
||||||
Token: "expired-test-id-token",
|
Token: "expired-test-id-token",
|
||||||
@ -298,16 +298,10 @@ func TestLogin(t *testing.T) {
|
|||||||
})
|
})
|
||||||
h.cache = cache
|
h.cache = cache
|
||||||
|
|
||||||
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
|
||||||
provider, err := oidc.NewProvider(ctx, iss)
|
|
||||||
require.NoError(t, err)
|
|
||||||
return &mockDiscovery{provider: provider}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
|
wantErr: "some validation error",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "session cache hit but refresh fails",
|
name: "session cache hit but refresh fails",
|
||||||
@ -328,12 +322,6 @@ func TestLogin(t *testing.T) {
|
|||||||
})
|
})
|
||||||
h.cache = cache
|
h.cache = cache
|
||||||
|
|
||||||
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
|
||||||
provider, err := oidc.NewProvider(ctx, iss)
|
|
||||||
require.NoError(t, err)
|
|
||||||
return &mockDiscovery{provider: provider}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
h.listenAddr = "invalid-listen-address"
|
h.listenAddr = "invalid-listen-address"
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@ -504,7 +492,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
method string
|
method string
|
||||||
query string
|
query string
|
||||||
returnIDTok string
|
opt func(t *testing.T) Option
|
||||||
wantErr string
|
wantErr string
|
||||||
wantHTTPStatus int
|
wantHTTPStatus int
|
||||||
}{
|
}{
|
||||||
@ -530,94 +518,49 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "invalid code",
|
name: "invalid code",
|
||||||
query: "state=test-state&code=invalid",
|
query: "state=test-state&code=invalid",
|
||||||
wantErr: "could not complete code exchange: oauth2: cannot fetch token: 403 Forbidden\nResponse: invalid authorization code\n",
|
wantErr: "could not complete code exchange: some exchange error",
|
||||||
wantHTTPStatus: http.StatusBadRequest,
|
wantHTTPStatus: http.StatusBadRequest,
|
||||||
},
|
opt: func(t *testing.T) Option {
|
||||||
{
|
return func(h *handlerState) error {
|
||||||
name: "missing ID token",
|
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
||||||
query: "state=test-state&code=valid",
|
mock := mockUpstream(t)
|
||||||
returnIDTok: "",
|
mock.EXPECT().
|
||||||
wantErr: "received response missing ID token",
|
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).
|
||||||
wantHTTPStatus: http.StatusBadRequest,
|
Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error"))
|
||||||
},
|
return mock
|
||||||
{
|
}
|
||||||
name: "invalid ID token",
|
return nil
|
||||||
query: "state=test-state&code=valid",
|
}
|
||||||
returnIDTok: "invalid-jwt",
|
},
|
||||||
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
|
|
||||||
wantHTTPStatus: http.StatusBadRequest,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid access token hash",
|
|
||||||
query: "state=test-state&code=valid",
|
|
||||||
|
|
||||||
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
|
|
||||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"at_hash": "invalid-at-hash"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
|
|
||||||
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdF9oYXNoIjoiaW52YWxpZC1hdC1oYXNoIiwiYXVkIjoidGVzdC1jbGllbnQtaWQiLCJpYXQiOjE2MDIyODM3OTEsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAyMjgzNzkxLCJzdWIiOiJ0ZXN0LXVzZXIifQ.jryXr4jiwcf79wBLaHpjdclEYHoUFGhvTu95QyA6Hnk9NQ0x1vsWYurtj7a8uKydNPryC_HNZi9QTAE_tRIJjycseog3695-5y4B4EZlqL-a94rdOtffuF2O_lnPbKvoja9EKNrp0kLBCftFRHhLAEwuP0N9E5padZwPpIGK0yE_JqljnYgCySvzsQu7tasR38yaULny13h3mtp2WRHPG5DrLyuBuF8Z01hSgRi5hGcVpgzTwBgV5-eMaSUCUo-ZDkqUsLQI6dVlaikCSKYZRb53HeexH0tB_R9PJJHY7mIr-rS76kkQEx9pLuVnheIH9Oc6zbdYWg-zWMijopA8Pg",
|
|
||||||
|
|
||||||
wantErr: "received invalid ID token: access token hash does not match value in ID token",
|
|
||||||
wantHTTPStatus: http.StatusBadRequest,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid nonce",
|
|
||||||
query: "state=test-state&code=valid",
|
|
||||||
|
|
||||||
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
|
|
||||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "invalid-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
|
|
||||||
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjI4Mzc0MSwianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDIyODM3NDEsIm5vbmNlIjoiaW52YWxpZC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.PRpq-7j5djaIAkraL-8t8ad9Xm4hM8RW67gyD1VIe0BecWeBFxsTuh3SZVKM9zmcwTgjudsyn8kQOwipDa49IN4PV8FcJA_uUJZi2wiqGJUSTG2K5I89doV_7e0RM1ZYIDDW1G2heKJNW7MbKkX7iEPr7u4MyEzswcPcupbyDA-CQFeL95vgwawoqa6yO94ympTbozqiNfj6Xyw_nHtThQnstjWsJZ9s2mUgppZezZv4HZYTQ7c3e_bzwhWgCzh2CSDJn9_Ra_n_4GcVkpHbsHTP35dFsnf0vactPx6CAu6A1-Apk-BruCktpZ3B4Ercf1UnUOHdGqzQKJtqvB03xQ",
|
|
||||||
|
|
||||||
wantHTTPStatus: http.StatusBadRequest,
|
|
||||||
wantErr: `received ID token with invalid nonce: invalid nonce (expected "test-nonce", got "invalid-nonce")`,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "valid",
|
name: "valid",
|
||||||
query: "state=test-state&code=valid",
|
query: "state=test-state&code=valid",
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
|
return func(h *handlerState) error {
|
||||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "test-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
|
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
||||||
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjUzMTU2NywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDI1MzE1NjcsIm5vbmNlIjoidGVzdC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.LbOA31iwJZBM4ayY5Oud-HArLXbmtAIhZv_LazDqbzA2Iw87RxoBemfiPUJeAesdnO1LKSjBwbltZwtjvbLWHp1R5tqrSMr_hl2OyZv1cpEX-9QaTcQILJ5qR00riRLz34ZCQFyF-FfQpP1r4dNqFrxHuiBwKuPE7zogc83ZYJgAQM5Fao9rIRY9JStL_3pURa9JnnSHFlkLvFYv3TKEUyvnW4pWvYZcsGI7mys43vuSjpG7ZSrW3vCxovuIpXYqAhamZL_XexWUsXvi3ej9HNlhnhOFhN4fuPSc0PWDWaN0CLWmoo8gvOdQWo5A4GD4bNGBzjYOd-pYqsDfseRt1Q",
|
mock := mockUpstream(t)
|
||||||
|
mock.EXPECT().
|
||||||
|
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).
|
||||||
|
Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil)
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
tt := tt
|
tt := tt
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
require.Equal(t, http.MethodPost, r.Method)
|
|
||||||
require.NoError(t, r.ParseForm())
|
|
||||||
require.Equal(t, "test-client-id", r.Form.Get("client_id"))
|
|
||||||
require.Equal(t, "test-pkce", r.Form.Get("code_verifier"))
|
|
||||||
require.Equal(t, "authorization_code", r.Form.Get("grant_type"))
|
|
||||||
require.NotEmpty(t, r.Form.Get("code"))
|
|
||||||
if r.Form.Get("code") != "valid" {
|
|
||||||
http.Error(w, "invalid authorization code", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var response struct {
|
|
||||||
oauth2.Token
|
|
||||||
IDToken string `json:"id_token,omitempty"`
|
|
||||||
}
|
|
||||||
response.AccessToken = "test-access-token"
|
|
||||||
response.Expiry = time.Now().Add(time.Hour)
|
|
||||||
response.IDToken = tt.returnIDTok
|
|
||||||
w.Header().Set("content-type", "application/json")
|
|
||||||
require.NoError(t, json.NewEncoder(w).Encode(&response))
|
|
||||||
}))
|
|
||||||
t.Cleanup(tokenServer.Close)
|
|
||||||
|
|
||||||
h := &handlerState{
|
h := &handlerState{
|
||||||
callbacks: make(chan callbackResult, 1),
|
callbacks: make(chan callbackResult, 1),
|
||||||
state: state.State("test-state"),
|
state: state.State("test-state"),
|
||||||
pkce: pkce.Code("test-pkce"),
|
pkce: pkce.Code("test-pkce"),
|
||||||
nonce: nonce.Nonce("test-nonce"),
|
nonce: nonce.Nonce("test-nonce"),
|
||||||
oauth2Config: &oauth2.Config{
|
}
|
||||||
ClientID: "test-client-id",
|
if tt.opt != nil {
|
||||||
RedirectURL: "http://localhost:12345/callback",
|
require.NoError(t, tt.opt(t)(h))
|
||||||
Endpoint: oauth2.Endpoint{
|
|
||||||
TokenURL: tokenServer.URL,
|
|
||||||
AuthStyle: oauth2.AuthStyleInParams,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
idTokenVerifier: mockVerifier(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
@ -653,34 +596,34 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
require.NoError(t, result.err)
|
require.NoError(t, result.err)
|
||||||
require.NotNil(t, result.token)
|
require.NotNil(t, result.token)
|
||||||
require.Equal(t, result.token.IDToken.Token, tt.returnIDTok)
|
require.Equal(t, result.token.IDToken.Token, "test-id-token")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// mockVerifier returns an *oidc.IDTokenVerifier that validates any correctly serialized JWT without doing much else.
|
func mockUpstream(t *testing.T) *mockupstreamoidcidentityprovider.MockUpstreamOIDCIdentityProviderI {
|
||||||
func mockVerifier() *oidc.IDTokenVerifier {
|
t.Helper()
|
||||||
mockKeySet := mockkeyset.NewMockKeySet(gomock.NewController(nil))
|
ctrl := gomock.NewController(t)
|
||||||
mockKeySet.EXPECT().VerifySignature(gomock.Any(), gomock.Any()).
|
t.Cleanup(ctrl.Finish)
|
||||||
AnyTimes().
|
return mockupstreamoidcidentityprovider.NewMockUpstreamOIDCIdentityProviderI(ctrl)
|
||||||
DoAndReturn(func(ctx context.Context, jwt string) ([]byte, error) {
|
|
||||||
jws, err := jose.ParseSigned(jwt)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return jws.UnsafePayloadWithoutVerification(), nil
|
|
||||||
})
|
|
||||||
|
|
||||||
return oidc.NewVerifier("", mockKeySet, &oidc.Config{
|
|
||||||
SkipIssuerCheck: true,
|
|
||||||
SkipExpiryCheck: true,
|
|
||||||
SkipClientIDCheck: true,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockDiscovery struct{ provider *oidc.Provider }
|
// hasAccessTokenMatcher is a gomock.Matcher that expects an *oauth2.Token with a particular access token.
|
||||||
|
type hasAccessTokenMatcher struct{ expected string }
|
||||||
|
|
||||||
func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() }
|
func (m hasAccessTokenMatcher) Matches(arg interface{}) bool {
|
||||||
|
return arg.(*oauth2.Token).AccessToken == m.expected
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() }
|
func (m hasAccessTokenMatcher) Got(got interface{}) string {
|
||||||
|
return got.(*oauth2.Token).AccessToken
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m hasAccessTokenMatcher) String() string {
|
||||||
|
return m.expected
|
||||||
|
}
|
||||||
|
|
||||||
|
func HasAccessToken(expected string) gomock.Matcher {
|
||||||
|
return hasAccessTokenMatcher{expected: expected}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user