Implement refresh flow in ./internal/oidcclient package.
Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
parent
8ae04605ca
commit
3508a28369
@ -30,6 +30,10 @@ const (
|
|||||||
// This is non-zero to ensure that most of the time, your ID token won't expire in the middle of a multi-step k8s
|
// This is non-zero to ensure that most of the time, your ID token won't expire in the middle of a multi-step k8s
|
||||||
// API operation.
|
// API operation.
|
||||||
minIDTokenValidity = 10 * time.Minute
|
minIDTokenValidity = 10 * time.Minute
|
||||||
|
|
||||||
|
// refreshTimeout is the amount of time allotted for OAuth2 refresh operations. Since these don't involve any
|
||||||
|
// user interaction, they should always be roughly as fast as network latency.
|
||||||
|
refreshTimeout = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
type handlerState struct {
|
type handlerState struct {
|
||||||
@ -56,6 +60,7 @@ type handlerState struct {
|
|||||||
generatePKCE func() (pkce.Code, error)
|
generatePKCE func() (pkce.Code, error)
|
||||||
generateNonce func() (nonce.Nonce, error)
|
generateNonce func() (nonce.Nonce, error)
|
||||||
openURL func(string) error
|
openURL func(string) error
|
||||||
|
oidcDiscover func(context.Context, string) (discoveryI, error)
|
||||||
|
|
||||||
callbacks chan callbackResult
|
callbacks chan callbackResult
|
||||||
}
|
}
|
||||||
@ -123,6 +128,11 @@ type nopCache struct{}
|
|||||||
func (*nopCache) GetToken(SessionCacheKey) *Token { return nil }
|
func (*nopCache) GetToken(SessionCacheKey) *Token { return nil }
|
||||||
func (*nopCache) PutToken(SessionCacheKey, *Token) {}
|
func (*nopCache) PutToken(SessionCacheKey, *Token) {}
|
||||||
|
|
||||||
|
type discoveryI interface {
|
||||||
|
Endpoint() oauth2.Endpoint
|
||||||
|
Verifier(*oidc.Config) *oidc.IDTokenVerifier
|
||||||
|
}
|
||||||
|
|
||||||
// Login performs an OAuth2/OIDC authorization code login using a localhost listener.
|
// Login performs an OAuth2/OIDC authorization code login using a localhost listener.
|
||||||
func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
|
func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
|
||||||
h := handlerState{
|
h := handlerState{
|
||||||
@ -140,6 +150,9 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
|
|||||||
generateNonce: nonce.Generate,
|
generateNonce: nonce.Generate,
|
||||||
generatePKCE: pkce.Generate,
|
generatePKCE: pkce.Generate,
|
||||||
openURL: browser.OpenURL,
|
openURL: browser.OpenURL,
|
||||||
|
oidcDiscover: func(ctx context.Context, iss string) (discoveryI, error) {
|
||||||
|
return oidc.NewProvider(ctx, iss)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
if err := opt(&h); err != nil {
|
if err := opt(&h); err != nil {
|
||||||
@ -177,52 +190,52 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If the ID token is still valid for a bit, return it immediately and skip the rest of the flow.
|
// If the ID token is still valid for a bit, return it immediately and skip the rest of the flow.
|
||||||
if cached := h.cache.GetToken(cacheKey); cached != nil &&
|
cached := h.cache.GetToken(cacheKey)
|
||||||
cached.IDToken != nil &&
|
if cached != nil && cached.IDToken != nil && time.Until(cached.IDToken.Expiry.Time) > minIDTokenValidity {
|
||||||
time.Until(cached.IDToken.Expiry.Time) > minIDTokenValidity {
|
|
||||||
return cached, nil
|
return cached, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform OIDC discovery.
|
// Perform OIDC discovery.
|
||||||
provider, err := oidc.NewProvider(h.ctx, h.issuer)
|
discovered, err := h.oidcDiscover(h.ctx, h.issuer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
|
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
|
||||||
}
|
}
|
||||||
h.idTokenVerifier = provider.Verifier(&oidc.Config{ClientID: h.clientID})
|
h.idTokenVerifier = discovered.Verifier(&oidc.Config{ClientID: h.clientID})
|
||||||
|
|
||||||
// Open a TCP listener.
|
|
||||||
listener, err := net.Listen("tcp", h.listenAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("could not open callback listener: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
|
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
|
||||||
h.oauth2Config = &oauth2.Config{
|
h.oauth2Config = &oauth2.Config{
|
||||||
ClientID: h.clientID,
|
ClientID: h.clientID,
|
||||||
Endpoint: provider.Endpoint(),
|
Endpoint: discovered.Endpoint(),
|
||||||
RedirectURL: (&url.URL{
|
|
||||||
Scheme: "http",
|
|
||||||
Host: listener.Addr().String(),
|
|
||||||
Path: h.callbackPath,
|
|
||||||
}).String(),
|
|
||||||
Scopes: h.scopes,
|
Scopes: h.scopes,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a callback server in a background goroutine.
|
// If there was a cached refresh token, attempt to use the refresh flow instead of a fresh login.
|
||||||
mux := http.NewServeMux()
|
if cached != nil && cached.RefreshToken != nil && cached.RefreshToken.Token != "" {
|
||||||
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))
|
freshToken, err := h.handleRefresh(ctx, cached.RefreshToken)
|
||||||
srv := http.Server{
|
if err != nil {
|
||||||
Handler: securityheader.Wrap(mux),
|
return nil, err
|
||||||
BaseContext: func(_ net.Listener) context.Context { return h.ctx },
|
|
||||||
}
|
}
|
||||||
go func() { _ = srv.Serve(listener) }()
|
// If we got a fresh token, we can update the cache and return it. Otherwise we fall through to the full refresh flow.
|
||||||
defer func() {
|
if freshToken != nil {
|
||||||
// Gracefully shut down the server, allowing up to 5 seconds for
|
h.cache.PutToken(cacheKey, freshToken)
|
||||||
// clients to receive any in-flight responses.
|
return freshToken, nil
|
||||||
shutdownCtx, cancel := context.WithTimeout(h.ctx, 1*time.Second)
|
}
|
||||||
_ = srv.Shutdown(shutdownCtx)
|
}
|
||||||
cancel()
|
|
||||||
}()
|
// Open a TCP listener and update the OAuth2 redirect_uri to match (in case we are using an ephemeral port number).
|
||||||
|
listener, err := net.Listen("tcp", h.listenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not open callback listener: %w", err)
|
||||||
|
}
|
||||||
|
h.oauth2Config.RedirectURL = (&url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: listener.Addr().String(),
|
||||||
|
Path: h.callbackPath,
|
||||||
|
}).String()
|
||||||
|
|
||||||
|
// Start a callback server in a background goroutine.
|
||||||
|
shutdown := h.serve(listener)
|
||||||
|
defer shutdown()
|
||||||
|
|
||||||
// Open the authorize URL in the users browser.
|
// Open the authorize URL in the users browser.
|
||||||
authorizeURL := h.oauth2Config.AuthCodeURL(
|
authorizeURL := h.oauth2Config.AuthCodeURL(
|
||||||
@ -249,6 +262,22 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *RefreshToken) (*Token, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, refreshTimeout)
|
||||||
|
defer cancel()
|
||||||
|
refreshSource := h.oauth2Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken.Token})
|
||||||
|
|
||||||
|
refreshed, err := refreshSource.Token()
|
||||||
|
if err != nil {
|
||||||
|
// Ignore errors during refresh, but return nil which will trigger the full login flow.
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
|
||||||
|
// some providers do not include one, so we skip the nonce validation here (but not other validations).
|
||||||
|
return h.validateToken(ctx, refreshed, false)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
|
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
|
||||||
// If we return an error, also report it back over the channel to the main CLI thread.
|
// If we return an error, also report it back over the channel to the main CLI thread.
|
||||||
defer func() {
|
defer func() {
|
||||||
@ -280,37 +309,64 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Perform required validations on the returned ID token.
|
// Perform required validations on the returned ID token.
|
||||||
idTok, hasIDTok := oauth2Tok.Extra("id_token").(string)
|
token, err := h.validateToken(r.Context(), oauth2Tok, true)
|
||||||
if !hasIDTok {
|
|
||||||
return httperr.New(http.StatusBadRequest, "received response missing ID token")
|
|
||||||
}
|
|
||||||
validated, err := h.idTokenVerifier.Verify(r.Context(), idTok)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
return err
|
||||||
}
|
|
||||||
if validated.AccessTokenHash != "" {
|
|
||||||
if err := validated.VerifyAccessToken(oauth2Tok.AccessToken); err != nil {
|
|
||||||
return httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err := h.nonce.Validate(validated); err != nil {
|
|
||||||
return httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
h.callbacks <- callbackResult{token: &Token{
|
h.callbacks <- callbackResult{token: token}
|
||||||
|
_, _ = w.Write([]byte("you have been logged in and may now close this tab"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*Token, error) {
|
||||||
|
idTok, hasIDTok := tok.Extra("id_token").(string)
|
||||||
|
if !hasIDTok {
|
||||||
|
return nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
|
||||||
|
}
|
||||||
|
validated, err := h.idTokenVerifier.Verify(ctx, idTok)
|
||||||
|
if err != nil {
|
||||||
|
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||||
|
}
|
||||||
|
if validated.AccessTokenHash != "" {
|
||||||
|
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
|
||||||
|
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if checkNonce {
|
||||||
|
if err := h.nonce.Validate(validated); err != nil {
|
||||||
|
return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &Token{
|
||||||
AccessToken: &AccessToken{
|
AccessToken: &AccessToken{
|
||||||
Token: oauth2Tok.AccessToken,
|
Token: tok.AccessToken,
|
||||||
Type: oauth2Tok.TokenType,
|
Type: tok.TokenType,
|
||||||
Expiry: metav1.NewTime(oauth2Tok.Expiry),
|
Expiry: metav1.NewTime(tok.Expiry),
|
||||||
},
|
},
|
||||||
RefreshToken: &RefreshToken{
|
RefreshToken: &RefreshToken{
|
||||||
Token: oauth2Tok.RefreshToken,
|
Token: tok.RefreshToken,
|
||||||
},
|
},
|
||||||
IDToken: &IDToken{
|
IDToken: &IDToken{
|
||||||
Token: idTok,
|
Token: idTok,
|
||||||
Expiry: metav1.NewTime(validated.Expiry),
|
Expiry: metav1.NewTime(validated.Expiry),
|
||||||
},
|
},
|
||||||
}}
|
}, nil
|
||||||
_, _ = w.Write([]byte("you have been logged in and may now close this tab"))
|
}
|
||||||
return nil
|
|
||||||
|
func (h *handlerState) serve(listener net.Listener) func() {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))
|
||||||
|
srv := http.Server{
|
||||||
|
Handler: securityheader.Wrap(mux),
|
||||||
|
BaseContext: func(_ net.Listener) context.Context { return h.ctx },
|
||||||
|
}
|
||||||
|
go func() { _ = srv.Serve(listener) }()
|
||||||
|
return func() {
|
||||||
|
// Gracefully shut down the server, allowing up to 5 seconds for
|
||||||
|
// clients to receive any in-flight responses.
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(h.ctx, 1*time.Second)
|
||||||
|
_ = srv.Shutdown(shutdownCtx)
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
|
|
||||||
"github.com/coreos/go-oidc"
|
"github.com/coreos/go-oidc"
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"gopkg.in/square/go-jose.v2"
|
"gopkg.in/square/go-jose.v2"
|
||||||
@ -49,7 +50,10 @@ func (m *mockSessionCache) PutToken(key SessionCacheKey, token *Token) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestLogin(t *testing.T) {
|
func TestLogin(t *testing.T) {
|
||||||
time1 := time.Date(3020, 10, 12, 13, 14, 15, 16, time.UTC)
|
time1 := time.Date(2035, 10, 12, 13, 14, 15, 16, time.UTC)
|
||||||
|
time1Unix := int64(2075807775)
|
||||||
|
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
|
||||||
|
|
||||||
testToken := Token{
|
testToken := Token{
|
||||||
AccessToken: &AccessToken{
|
AccessToken: &AccessToken{
|
||||||
Token: "test-access-token",
|
Token: "test-access-token",
|
||||||
@ -59,7 +63,9 @@ func TestLogin(t *testing.T) {
|
|||||||
Token: "test-refresh-token",
|
Token: "test-refresh-token",
|
||||||
},
|
},
|
||||||
IDToken: &IDToken{
|
IDToken: &IDToken{
|
||||||
Token: "test-id-token",
|
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above):
|
||||||
|
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775
|
||||||
|
Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA",
|
||||||
Expiry: metav1.NewTime(time1.Add(2 * time.Minute)),
|
Expiry: metav1.NewTime(time1.Add(2 * time.Minute)),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -70,11 +76,15 @@ func TestLogin(t *testing.T) {
|
|||||||
}))
|
}))
|
||||||
t.Cleanup(errorServer.Close)
|
t.Cleanup(errorServer.Close)
|
||||||
|
|
||||||
// Start a test server that returns a real keyset
|
// Start a test server that returns a real keyset and answers refresh requests.
|
||||||
providerMux := http.NewServeMux()
|
providerMux := http.NewServeMux()
|
||||||
successServer := httptest.NewServer(providerMux)
|
successServer := httptest.NewServer(providerMux)
|
||||||
t.Cleanup(successServer.Close)
|
t.Cleanup(successServer.Close)
|
||||||
providerMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
providerMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
w.Header().Set("content-type", "application/json")
|
w.Header().Set("content-type", "application/json")
|
||||||
type providerJSON struct {
|
type providerJSON struct {
|
||||||
Issuer string `json:"issuer"`
|
Issuer string `json:"issuer"`
|
||||||
@ -89,6 +99,44 @@ func TestLogin(t *testing.T) {
|
|||||||
JWKSURL: successServer.URL + "/keys",
|
JWKSURL: successServer.URL + "/keys",
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
providerMux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Form.Get("client_id") != "test-client-id" {
|
||||||
|
http.Error(w, "expected client_id 'test-client-id'", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Form.Get("grant_type") != "refresh_token" {
|
||||||
|
http.Error(w, "expected refresh_token grant type", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var response struct {
|
||||||
|
oauth2.Token
|
||||||
|
IDToken string `json:"id_token,omitempty"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
}
|
||||||
|
response.AccessToken = testToken.AccessToken.Token
|
||||||
|
response.ExpiresIn = int64(time.Until(testToken.AccessToken.Expiry.Time).Seconds())
|
||||||
|
response.RefreshToken = testToken.RefreshToken.Token
|
||||||
|
response.IDToken = testToken.IDToken.Token
|
||||||
|
|
||||||
|
if r.Form.Get("refresh_token") == "test-refresh-token-returning-invalid-id-token" {
|
||||||
|
response.IDToken = "not a valid JWT"
|
||||||
|
} else if r.Form.Get("refresh_token") != "test-refresh-token" {
|
||||||
|
http.Error(w, "expected refresh_token to be 'test-refresh-token'", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("content-type", "application/json")
|
||||||
|
require.NoError(t, json.NewEncoder(w).Encode(&response))
|
||||||
|
})
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@ -192,6 +240,106 @@ func TestLogin(t *testing.T) {
|
|||||||
issuer: errorServer.URL,
|
issuer: errorServer.URL,
|
||||||
wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL),
|
wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "session cache hit with refreshable token",
|
||||||
|
issuer: successServer.URL,
|
||||||
|
clientID: "test-client-id",
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
|
||||||
|
IDToken: &IDToken{
|
||||||
|
Token: "expired-test-id-token",
|
||||||
|
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
|
||||||
|
},
|
||||||
|
RefreshToken: &RefreshToken{Token: "test-refresh-token"},
|
||||||
|
}}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cacheKey := SessionCacheKey{
|
||||||
|
Issuer: successServer.URL,
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: []string{"test-scope"},
|
||||||
|
RedirectURI: "http://localhost:0/callback",
|
||||||
|
}
|
||||||
|
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
|
||||||
|
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
|
||||||
|
require.Len(t, cache.sawPutTokens, 1)
|
||||||
|
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
|
||||||
|
})
|
||||||
|
h.cache = cache
|
||||||
|
|
||||||
|
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
||||||
|
provider, err := oidc.NewProvider(ctx, iss)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return &mockDiscovery{provider: provider}, nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantToken: &testToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session cache hit but refresh returns invalid token",
|
||||||
|
issuer: successServer.URL,
|
||||||
|
clientID: "test-client-id",
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
|
||||||
|
IDToken: &IDToken{
|
||||||
|
Token: "expired-test-id-token",
|
||||||
|
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
|
||||||
|
},
|
||||||
|
RefreshToken: &RefreshToken{Token: "test-refresh-token-returning-invalid-id-token"},
|
||||||
|
}}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.Empty(t, cache.sawPutKeys)
|
||||||
|
require.Empty(t, cache.sawPutTokens)
|
||||||
|
})
|
||||||
|
h.cache = cache
|
||||||
|
|
||||||
|
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
||||||
|
provider, err := oidc.NewProvider(ctx, iss)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return &mockDiscovery{provider: provider}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "session cache hit but refresh fails",
|
||||||
|
issuer: successServer.URL,
|
||||||
|
clientID: "not-the-test-client-id",
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
|
||||||
|
IDToken: &IDToken{
|
||||||
|
Token: "expired-test-id-token",
|
||||||
|
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
|
||||||
|
},
|
||||||
|
RefreshToken: &RefreshToken{Token: "test-refresh-token"},
|
||||||
|
}}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.Empty(t, cache.sawPutKeys)
|
||||||
|
require.Empty(t, cache.sawPutTokens)
|
||||||
|
})
|
||||||
|
h.cache = cache
|
||||||
|
|
||||||
|
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
||||||
|
provider, err := oidc.NewProvider(ctx, iss)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return &mockDiscovery{provider: provider}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
h.listenAddr = "invalid-listen-address"
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
// Expect this to fall through to the authorization code flow, so it fails here.
|
||||||
|
wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "listen failure",
|
name: "listen failure",
|
||||||
opt: func(t *testing.T) Option {
|
opt: func(t *testing.T) Option {
|
||||||
@ -320,7 +468,30 @@ func TestLogin(t *testing.T) {
|
|||||||
require.Nil(t, tok)
|
require.Nil(t, tok)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
require.Equal(t, tt.wantToken, tok)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if tt.wantToken == nil {
|
||||||
|
require.Nil(t, tok)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NotNil(t, tok)
|
||||||
|
|
||||||
|
if want := tt.wantToken.AccessToken; want != nil {
|
||||||
|
require.NotNil(t, tok.AccessToken)
|
||||||
|
require.Equal(t, want.Token, tok.AccessToken.Token)
|
||||||
|
require.Equal(t, want.Type, tok.AccessToken.Type)
|
||||||
|
requireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second)
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, tok.AccessToken)
|
||||||
|
}
|
||||||
|
require.Equal(t, tt.wantToken.RefreshToken, tok.RefreshToken)
|
||||||
|
if want := tt.wantToken.IDToken; want != nil {
|
||||||
|
require.NotNil(t, tok.IDToken)
|
||||||
|
require.Equal(t, want.Token, tok.IDToken.Token)
|
||||||
|
requireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second)
|
||||||
|
} else {
|
||||||
|
assert.Nil(t, tok.IDToken)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -504,3 +675,22 @@ func mockVerifier() *oidc.IDTokenVerifier {
|
|||||||
SkipClientIDCheck: true,
|
SkipClientIDCheck: true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type mockDiscovery struct{ provider *oidc.Provider }
|
||||||
|
|
||||||
|
func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() }
|
||||||
|
|
||||||
|
func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() }
|
||||||
|
|
||||||
|
func requireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) {
|
||||||
|
require.InDeltaf(t,
|
||||||
|
float64(t1.UnixNano()),
|
||||||
|
float64(t2.UnixNano()),
|
||||||
|
float64(delta.Nanoseconds()),
|
||||||
|
"expected %s and %s to be < %s apart, but they are %s apart",
|
||||||
|
t1.Format(time.RFC3339Nano),
|
||||||
|
t2.Format(time.RFC3339Nano),
|
||||||
|
delta.String(),
|
||||||
|
t1.Sub(t2).String(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user