Refactor oidcclient.Login to use new upstreamoidc package.
Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
parent
4b60c922ef
commit
b272b3f331
@ -20,6 +20,10 @@ import (
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
)
|
||||
|
||||
func New(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
||||
return &ProviderConfig{Config: config, Provider: provider}
|
||||
}
|
||||
|
||||
// ProviderConfig holds the active configuration of an upstream OIDC provider.
|
||||
type ProviderConfig struct {
|
||||
Name string
|
||||
@ -31,9 +35,6 @@ type ProviderConfig struct {
|
||||
}
|
||||
}
|
||||
|
||||
// *ProviderConfig should implement provider.UpstreamOIDCIdentityProviderI.
|
||||
var _ provider.UpstreamOIDCIdentityProviderI = (*ProviderConfig)(nil)
|
||||
|
||||
func (p *ProviderConfig) GetName() string {
|
||||
return p.Name
|
||||
}
|
||||
|
@ -16,10 +16,11 @@ import (
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/pkg/browser"
|
||||
"golang.org/x/oauth2"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/httputil/securityheader"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/upstreamoidc"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
"go.pinniped.dev/pkg/oidcclient/pkce"
|
||||
@ -52,18 +53,18 @@ type handlerState struct {
|
||||
callbackPath string
|
||||
|
||||
// Generated parameters of a login flow.
|
||||
idTokenVerifier *oidc.IDTokenVerifier
|
||||
oauth2Config *oauth2.Config
|
||||
state state.State
|
||||
nonce nonce.Nonce
|
||||
pkce pkce.Code
|
||||
provider *oidc.Provider
|
||||
oauth2Config *oauth2.Config
|
||||
state state.State
|
||||
nonce nonce.Nonce
|
||||
pkce pkce.Code
|
||||
|
||||
// External calls for things.
|
||||
generateState func() (state.State, error)
|
||||
generatePKCE func() (pkce.Code, error)
|
||||
generateNonce func() (nonce.Nonce, error)
|
||||
openURL func(string) error
|
||||
oidcDiscover func(context.Context, string) (discoveryI, error)
|
||||
getProvider func(*oauth2.Config, *oidc.Provider) provider.UpstreamOIDCIdentityProviderI
|
||||
|
||||
callbacks chan callbackResult
|
||||
}
|
||||
@ -152,11 +153,6 @@ type nopCache struct{}
|
||||
func (*nopCache) GetToken(SessionCacheKey) *oidctypes.Token { return nil }
|
||||
func (*nopCache) PutToken(SessionCacheKey, *oidctypes.Token) {}
|
||||
|
||||
type discoveryI interface {
|
||||
Endpoint() oauth2.Endpoint
|
||||
Verifier(*oidc.Config) *oidc.IDTokenVerifier
|
||||
}
|
||||
|
||||
// Login performs an OAuth2/OIDC authorization code login using a localhost listener.
|
||||
func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, error) {
|
||||
h := handlerState{
|
||||
@ -175,9 +171,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
|
||||
generateNonce: nonce.Generate,
|
||||
generatePKCE: pkce.Generate,
|
||||
openURL: browser.OpenURL,
|
||||
oidcDiscover: func(ctx context.Context, iss string) (discoveryI, error) {
|
||||
return oidc.NewProvider(ctx, iss)
|
||||
},
|
||||
getProvider: upstreamoidc.New,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if err := opt(&h); err != nil {
|
||||
@ -222,16 +216,15 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
|
||||
}
|
||||
|
||||
// Perform OIDC discovery.
|
||||
discovered, err := h.oidcDiscover(h.ctx, h.issuer)
|
||||
h.provider, err = oidc.NewProvider(h.ctx, h.issuer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
|
||||
}
|
||||
h.idTokenVerifier = discovered.Verifier(&oidc.Config{ClientID: h.clientID})
|
||||
|
||||
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
|
||||
h.oauth2Config = &oauth2.Config{
|
||||
ClientID: h.clientID,
|
||||
Endpoint: discovered.Endpoint(),
|
||||
Endpoint: h.provider.Endpoint(),
|
||||
Scopes: h.scopes,
|
||||
}
|
||||
|
||||
@ -301,7 +294,11 @@ func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctype
|
||||
|
||||
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
|
||||
// some providers do not include one, so we skip the nonce validation here (but not other validations).
|
||||
return h.validateToken(ctx, refreshed, false)
|
||||
token, _, err := h.getProvider(h.oauth2Config, h.provider).ValidateToken(ctx, refreshed, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &token, nil
|
||||
}
|
||||
|
||||
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
|
||||
@ -328,58 +325,18 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
|
||||
return httperr.Newf(http.StatusBadRequest, "login failed with code %q", errorParam)
|
||||
}
|
||||
|
||||
// Exchange the authorization code for access, ID, and refresh tokens.
|
||||
oauth2Tok, err := h.oauth2Config.Exchange(r.Context(), params.Get("code"), h.pkce.Verifier())
|
||||
// Exchange the authorization code for access, ID, and refresh tokens and perform required
|
||||
// validations on the returned ID token.
|
||||
token, _, err := h.getProvider(h.oauth2Config, h.provider).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce)
|
||||
if err != nil {
|
||||
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
|
||||
}
|
||||
|
||||
// Perform required validations on the returned ID token.
|
||||
token, err := h.validateToken(r.Context(), oauth2Tok, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.callbacks <- callbackResult{token: token}
|
||||
h.callbacks <- callbackResult{token: &token}
|
||||
_, _ = w.Write([]byte("you have been logged in and may now close this tab"))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*oidctypes.Token, error) {
|
||||
idTok, hasIDTok := tok.Extra("id_token").(string)
|
||||
if !hasIDTok {
|
||||
return nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
|
||||
}
|
||||
validated, err := h.idTokenVerifier.Verify(ctx, idTok)
|
||||
if err != nil {
|
||||
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||
}
|
||||
if validated.AccessTokenHash != "" {
|
||||
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
|
||||
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||
}
|
||||
}
|
||||
if checkNonce {
|
||||
if err := h.nonce.Validate(validated); err != nil {
|
||||
return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
|
||||
}
|
||||
}
|
||||
return &oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: tok.AccessToken,
|
||||
Type: tok.TokenType,
|
||||
Expiry: metav1.NewTime(tok.Expiry),
|
||||
},
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: tok.RefreshToken,
|
||||
},
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: idTok,
|
||||
Expiry: metav1.NewTime(validated.Expiry),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *handlerState) serve(listener net.Listener) func() {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))
|
||||
|
@ -18,11 +18,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||
|
||||
"go.pinniped.dev/internal/httputil/httperr"
|
||||
"go.pinniped.dev/internal/mocks/mockkeyset"
|
||||
"go.pinniped.dev/internal/mocks/mockupstreamoidcidentityprovider"
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/testutil"
|
||||
"go.pinniped.dev/pkg/oidcclient/nonce"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
@ -57,19 +57,9 @@ func TestLogin(t *testing.T) {
|
||||
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
|
||||
|
||||
testToken := oidctypes.Token{
|
||||
AccessToken: &oidctypes.AccessToken{
|
||||
Token: "test-access-token",
|
||||
Expiry: metav1.NewTime(time1.Add(1 * time.Minute)),
|
||||
},
|
||||
RefreshToken: &oidctypes.RefreshToken{
|
||||
Token: "test-refresh-token",
|
||||
},
|
||||
IDToken: &oidctypes.IDToken{
|
||||
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above):
|
||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775
|
||||
Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA",
|
||||
Expiry: metav1.NewTime(time1.Add(2 * time.Minute)),
|
||||
},
|
||||
AccessToken: &oidctypes.AccessToken{Token: "test-access-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute))},
|
||||
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
|
||||
IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))},
|
||||
}
|
||||
|
||||
// Start a test server that returns 500 errors
|
||||
@ -78,7 +68,7 @@ func TestLogin(t *testing.T) {
|
||||
}))
|
||||
t.Cleanup(errorServer.Close)
|
||||
|
||||
// Start a test server that returns a real keyset and answers refresh requests.
|
||||
// Start a test server that returns a real discovery document and answers refresh requests.
|
||||
providerMux := http.NewServeMux()
|
||||
successServer := httptest.NewServer(providerMux)
|
||||
t.Cleanup(successServer.Close)
|
||||
@ -248,6 +238,14 @@ func TestLogin(t *testing.T) {
|
||||
clientID: "test-client-id",
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
||||
mock := mockUpstream(t)
|
||||
mock.EXPECT().
|
||||
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
|
||||
Return(testToken, nil, nil)
|
||||
return mock
|
||||
}
|
||||
|
||||
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "expired-test-id-token",
|
||||
@ -268,12 +266,6 @@ func TestLogin(t *testing.T) {
|
||||
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
|
||||
})
|
||||
h.cache = cache
|
||||
|
||||
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
||||
provider, err := oidc.NewProvider(ctx, iss)
|
||||
require.NoError(t, err)
|
||||
return &mockDiscovery{provider: provider}, nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
},
|
||||
@ -285,6 +277,14 @@ func TestLogin(t *testing.T) {
|
||||
clientID: "test-client-id",
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
h.getProvider = func(config *oauth2.Config, o *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
||||
mock := mockUpstream(t)
|
||||
mock.EXPECT().
|
||||
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
|
||||
Return(oidctypes.Token{}, nil, fmt.Errorf("some validation error"))
|
||||
return mock
|
||||
}
|
||||
|
||||
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Token: "expired-test-id-token",
|
||||
@ -298,16 +298,10 @@ func TestLogin(t *testing.T) {
|
||||
})
|
||||
h.cache = cache
|
||||
|
||||
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
||||
provider, err := oidc.NewProvider(ctx, iss)
|
||||
require.NoError(t, err)
|
||||
return &mockDiscovery{provider: provider}, nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
},
|
||||
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
|
||||
wantErr: "some validation error",
|
||||
},
|
||||
{
|
||||
name: "session cache hit but refresh fails",
|
||||
@ -328,12 +322,6 @@ func TestLogin(t *testing.T) {
|
||||
})
|
||||
h.cache = cache
|
||||
|
||||
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
||||
provider, err := oidc.NewProvider(ctx, iss)
|
||||
require.NoError(t, err)
|
||||
return &mockDiscovery{provider: provider}, nil
|
||||
}
|
||||
|
||||
h.listenAddr = "invalid-listen-address"
|
||||
|
||||
return nil
|
||||
@ -504,7 +492,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
||||
name string
|
||||
method string
|
||||
query string
|
||||
returnIDTok string
|
||||
opt func(t *testing.T) Option
|
||||
wantErr string
|
||||
wantHTTPStatus int
|
||||
}{
|
||||
@ -530,94 +518,49 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
||||
{
|
||||
name: "invalid code",
|
||||
query: "state=test-state&code=invalid",
|
||||
wantErr: "could not complete code exchange: oauth2: cannot fetch token: 403 Forbidden\nResponse: invalid authorization code\n",
|
||||
wantErr: "could not complete code exchange: some exchange error",
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "missing ID token",
|
||||
query: "state=test-state&code=valid",
|
||||
returnIDTok: "",
|
||||
wantErr: "received response missing ID token",
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invalid ID token",
|
||||
query: "state=test-state&code=valid",
|
||||
returnIDTok: "invalid-jwt",
|
||||
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invalid access token hash",
|
||||
query: "state=test-state&code=valid",
|
||||
|
||||
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
|
||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"at_hash": "invalid-at-hash"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
|
||||
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdF9oYXNoIjoiaW52YWxpZC1hdC1oYXNoIiwiYXVkIjoidGVzdC1jbGllbnQtaWQiLCJpYXQiOjE2MDIyODM3OTEsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAyMjgzNzkxLCJzdWIiOiJ0ZXN0LXVzZXIifQ.jryXr4jiwcf79wBLaHpjdclEYHoUFGhvTu95QyA6Hnk9NQ0x1vsWYurtj7a8uKydNPryC_HNZi9QTAE_tRIJjycseog3695-5y4B4EZlqL-a94rdOtffuF2O_lnPbKvoja9EKNrp0kLBCftFRHhLAEwuP0N9E5padZwPpIGK0yE_JqljnYgCySvzsQu7tasR38yaULny13h3mtp2WRHPG5DrLyuBuF8Z01hSgRi5hGcVpgzTwBgV5-eMaSUCUo-ZDkqUsLQI6dVlaikCSKYZRb53HeexH0tB_R9PJJHY7mIr-rS76kkQEx9pLuVnheIH9Oc6zbdYWg-zWMijopA8Pg",
|
||||
|
||||
wantErr: "received invalid ID token: access token hash does not match value in ID token",
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invalid nonce",
|
||||
query: "state=test-state&code=valid",
|
||||
|
||||
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
|
||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "invalid-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
|
||||
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjI4Mzc0MSwianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDIyODM3NDEsIm5vbmNlIjoiaW52YWxpZC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.PRpq-7j5djaIAkraL-8t8ad9Xm4hM8RW67gyD1VIe0BecWeBFxsTuh3SZVKM9zmcwTgjudsyn8kQOwipDa49IN4PV8FcJA_uUJZi2wiqGJUSTG2K5I89doV_7e0RM1ZYIDDW1G2heKJNW7MbKkX7iEPr7u4MyEzswcPcupbyDA-CQFeL95vgwawoqa6yO94ympTbozqiNfj6Xyw_nHtThQnstjWsJZ9s2mUgppZezZv4HZYTQ7c3e_bzwhWgCzh2CSDJn9_Ra_n_4GcVkpHbsHTP35dFsnf0vactPx6CAu6A1-Apk-BruCktpZ3B4Ercf1UnUOHdGqzQKJtqvB03xQ",
|
||||
|
||||
wantHTTPStatus: http.StatusBadRequest,
|
||||
wantErr: `received ID token with invalid nonce: invalid nonce (expected "test-nonce", got "invalid-nonce")`,
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
||||
mock := mockUpstream(t)
|
||||
mock.EXPECT().
|
||||
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).
|
||||
Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error"))
|
||||
return mock
|
||||
}
|
||||
return nil
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
query: "state=test-state&code=valid",
|
||||
|
||||
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
|
||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "test-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
|
||||
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjUzMTU2NywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDI1MzE1NjcsIm5vbmNlIjoidGVzdC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.LbOA31iwJZBM4ayY5Oud-HArLXbmtAIhZv_LazDqbzA2Iw87RxoBemfiPUJeAesdnO1LKSjBwbltZwtjvbLWHp1R5tqrSMr_hl2OyZv1cpEX-9QaTcQILJ5qR00riRLz34ZCQFyF-FfQpP1r4dNqFrxHuiBwKuPE7zogc83ZYJgAQM5Fao9rIRY9JStL_3pURa9JnnSHFlkLvFYv3TKEUyvnW4pWvYZcsGI7mys43vuSjpG7ZSrW3vCxovuIpXYqAhamZL_XexWUsXvi3ej9HNlhnhOFhN4fuPSc0PWDWaN0CLWmoo8gvOdQWo5A4GD4bNGBzjYOd-pYqsDfseRt1Q",
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider) provider.UpstreamOIDCIdentityProviderI {
|
||||
mock := mockUpstream(t)
|
||||
mock.EXPECT().
|
||||
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")).
|
||||
Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil)
|
||||
return mock
|
||||
}
|
||||
return nil
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(t, http.MethodPost, r.Method)
|
||||
require.NoError(t, r.ParseForm())
|
||||
require.Equal(t, "test-client-id", r.Form.Get("client_id"))
|
||||
require.Equal(t, "test-pkce", r.Form.Get("code_verifier"))
|
||||
require.Equal(t, "authorization_code", r.Form.Get("grant_type"))
|
||||
require.NotEmpty(t, r.Form.Get("code"))
|
||||
if r.Form.Get("code") != "valid" {
|
||||
http.Error(w, "invalid authorization code", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
var response struct {
|
||||
oauth2.Token
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
}
|
||||
response.AccessToken = "test-access-token"
|
||||
response.Expiry = time.Now().Add(time.Hour)
|
||||
response.IDToken = tt.returnIDTok
|
||||
w.Header().Set("content-type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(w).Encode(&response))
|
||||
}))
|
||||
t.Cleanup(tokenServer.Close)
|
||||
|
||||
h := &handlerState{
|
||||
callbacks: make(chan callbackResult, 1),
|
||||
state: state.State("test-state"),
|
||||
pkce: pkce.Code("test-pkce"),
|
||||
nonce: nonce.Nonce("test-nonce"),
|
||||
oauth2Config: &oauth2.Config{
|
||||
ClientID: "test-client-id",
|
||||
RedirectURL: "http://localhost:12345/callback",
|
||||
Endpoint: oauth2.Endpoint{
|
||||
TokenURL: tokenServer.URL,
|
||||
AuthStyle: oauth2.AuthStyleInParams,
|
||||
},
|
||||
},
|
||||
idTokenVerifier: mockVerifier(),
|
||||
}
|
||||
if tt.opt != nil {
|
||||
require.NoError(t, tt.opt(t)(h))
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
@ -653,34 +596,34 @@ func TestHandleAuthCodeCallback(t *testing.T) {
|
||||
}
|
||||
require.NoError(t, result.err)
|
||||
require.NotNil(t, result.token)
|
||||
require.Equal(t, result.token.IDToken.Token, tt.returnIDTok)
|
||||
require.Equal(t, result.token.IDToken.Token, "test-id-token")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// mockVerifier returns an *oidc.IDTokenVerifier that validates any correctly serialized JWT without doing much else.
|
||||
func mockVerifier() *oidc.IDTokenVerifier {
|
||||
mockKeySet := mockkeyset.NewMockKeySet(gomock.NewController(nil))
|
||||
mockKeySet.EXPECT().VerifySignature(gomock.Any(), gomock.Any()).
|
||||
AnyTimes().
|
||||
DoAndReturn(func(ctx context.Context, jwt string) ([]byte, error) {
|
||||
jws, err := jose.ParseSigned(jwt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return jws.UnsafePayloadWithoutVerification(), nil
|
||||
})
|
||||
|
||||
return oidc.NewVerifier("", mockKeySet, &oidc.Config{
|
||||
SkipIssuerCheck: true,
|
||||
SkipExpiryCheck: true,
|
||||
SkipClientIDCheck: true,
|
||||
})
|
||||
func mockUpstream(t *testing.T) *mockupstreamoidcidentityprovider.MockUpstreamOIDCIdentityProviderI {
|
||||
t.Helper()
|
||||
ctrl := gomock.NewController(t)
|
||||
t.Cleanup(ctrl.Finish)
|
||||
return mockupstreamoidcidentityprovider.NewMockUpstreamOIDCIdentityProviderI(ctrl)
|
||||
}
|
||||
|
||||
type mockDiscovery struct{ provider *oidc.Provider }
|
||||
// hasAccessTokenMatcher is a gomock.Matcher that expects an *oauth2.Token with a particular access token.
|
||||
type hasAccessTokenMatcher struct{ expected string }
|
||||
|
||||
func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() }
|
||||
func (m hasAccessTokenMatcher) Matches(arg interface{}) bool {
|
||||
return arg.(*oauth2.Token).AccessToken == m.expected
|
||||
}
|
||||
|
||||
func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() }
|
||||
func (m hasAccessTokenMatcher) Got(got interface{}) string {
|
||||
return got.(*oauth2.Token).AccessToken
|
||||
}
|
||||
|
||||
func (m hasAccessTokenMatcher) String() string {
|
||||
return m.expected
|
||||
}
|
||||
|
||||
func HasAccessToken(expected string) gomock.Matcher {
|
||||
return hasAccessTokenMatcher{expected: expected}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user