Merge pull request #168 from mattmoyer/cli-session-refresh
Add support for refresh token flow in OIDC CLI client.
This commit is contained in:
commit
4c844ba334
@ -30,6 +30,10 @@ const (
|
||||
// This is non-zero to ensure that most of the time, your ID token won't expire in the middle of a multi-step k8s
|
||||
// API operation.
|
||||
minIDTokenValidity = 10 * time.Minute
|
||||
|
||||
// refreshTimeout is the amount of time allotted for OAuth2 refresh operations. Since these don't involve any
|
||||
// user interaction, they should always be roughly as fast as network latency.
|
||||
refreshTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type handlerState struct {
|
||||
@ -56,6 +60,7 @@ type handlerState struct {
|
||||
generatePKCE func() (pkce.Code, error)
|
||||
generateNonce func() (nonce.Nonce, error)
|
||||
openURL func(string) error
|
||||
oidcDiscover func(context.Context, string) (discoveryI, error)
|
||||
|
||||
callbacks chan callbackResult
|
||||
}
|
||||
@ -123,6 +128,11 @@ type nopCache struct{}
|
||||
func (*nopCache) GetToken(SessionCacheKey) *Token { return nil }
|
||||
func (*nopCache) PutToken(SessionCacheKey, *Token) {}
|
||||
|
||||
type discoveryI interface {
|
||||
Endpoint() oauth2.Endpoint
|
||||
Verifier(*oidc.Config) *oidc.IDTokenVerifier
|
||||
}
|
||||
|
||||
// Login performs an OAuth2/OIDC authorization code login using a localhost listener.
|
||||
func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
|
||||
h := handlerState{
|
||||
@ -140,6 +150,9 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
|
||||
generateNonce: nonce.Generate,
|
||||
generatePKCE: pkce.Generate,
|
||||
openURL: browser.OpenURL,
|
||||
oidcDiscover: func(ctx context.Context, iss string) (discoveryI, error) {
|
||||
return oidc.NewProvider(ctx, iss)
|
||||
},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
if err := opt(&h); err != nil {
|
||||
@ -177,52 +190,52 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
|
||||
}
|
||||
|
||||
// If the ID token is still valid for a bit, return it immediately and skip the rest of the flow.
|
||||
if cached := h.cache.GetToken(cacheKey); cached != nil &&
|
||||
cached.IDToken != nil &&
|
||||
time.Until(cached.IDToken.Expiry.Time) > minIDTokenValidity {
|
||||
cached := h.cache.GetToken(cacheKey)
|
||||
if cached != nil && cached.IDToken != nil && time.Until(cached.IDToken.Expiry.Time) > minIDTokenValidity {
|
||||
return cached, nil
|
||||
}
|
||||
|
||||
// Perform OIDC discovery.
|
||||
provider, err := oidc.NewProvider(h.ctx, h.issuer)
|
||||
discovered, err := h.oidcDiscover(h.ctx, h.issuer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
|
||||
}
|
||||
h.idTokenVerifier = provider.Verifier(&oidc.Config{ClientID: h.clientID})
|
||||
|
||||
// Open a TCP listener.
|
||||
listener, err := net.Listen("tcp", h.listenAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not open callback listener: %w", err)
|
||||
}
|
||||
h.idTokenVerifier = discovered.Verifier(&oidc.Config{ClientID: h.clientID})
|
||||
|
||||
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
|
||||
h.oauth2Config = &oauth2.Config{
|
||||
ClientID: h.clientID,
|
||||
Endpoint: provider.Endpoint(),
|
||||
RedirectURL: (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: listener.Addr().String(),
|
||||
Path: h.callbackPath,
|
||||
}).String(),
|
||||
Endpoint: discovered.Endpoint(),
|
||||
Scopes: h.scopes,
|
||||
}
|
||||
|
||||
// Start a callback server in a background goroutine.
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))
|
||||
srv := http.Server{
|
||||
Handler: securityheader.Wrap(mux),
|
||||
BaseContext: func(_ net.Listener) context.Context { return h.ctx },
|
||||
// If there was a cached refresh token, attempt to use the refresh flow instead of a fresh login.
|
||||
if cached != nil && cached.RefreshToken != nil && cached.RefreshToken.Token != "" {
|
||||
freshToken, err := h.handleRefresh(ctx, cached.RefreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go func() { _ = srv.Serve(listener) }()
|
||||
defer func() {
|
||||
// Gracefully shut down the server, allowing up to 5 seconds for
|
||||
// clients to receive any in-flight responses.
|
||||
shutdownCtx, cancel := context.WithTimeout(h.ctx, 1*time.Second)
|
||||
_ = srv.Shutdown(shutdownCtx)
|
||||
cancel()
|
||||
}()
|
||||
// If we got a fresh token, we can update the cache and return it. Otherwise we fall through to the full refresh flow.
|
||||
if freshToken != nil {
|
||||
h.cache.PutToken(cacheKey, freshToken)
|
||||
return freshToken, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Open a TCP listener and update the OAuth2 redirect_uri to match (in case we are using an ephemeral port number).
|
||||
listener, err := net.Listen("tcp", h.listenAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not open callback listener: %w", err)
|
||||
}
|
||||
h.oauth2Config.RedirectURL = (&url.URL{
|
||||
Scheme: "http",
|
||||
Host: listener.Addr().String(),
|
||||
Path: h.callbackPath,
|
||||
}).String()
|
||||
|
||||
// Start a callback server in a background goroutine.
|
||||
shutdown := h.serve(listener)
|
||||
defer shutdown()
|
||||
|
||||
// Open the authorize URL in the users browser.
|
||||
authorizeURL := h.oauth2Config.AuthCodeURL(
|
||||
@ -249,6 +262,22 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *RefreshToken) (*Token, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, refreshTimeout)
|
||||
defer cancel()
|
||||
refreshSource := h.oauth2Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken.Token})
|
||||
|
||||
refreshed, err := refreshSource.Token()
|
||||
if err != nil {
|
||||
// Ignore errors during refresh, but return nil which will trigger the full login flow.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
|
||||
// some providers do not include one, so we skip the nonce validation here (but not other validations).
|
||||
return h.validateToken(ctx, refreshed, false)
|
||||
}
|
||||
|
||||
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
|
||||
// If we return an error, also report it back over the channel to the main CLI thread.
|
||||
defer func() {
|
||||
@ -280,37 +309,64 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
|
||||
// Perform required validations on the returned ID token.
|
||||
idTok, hasIDTok := oauth2Tok.Extra("id_token").(string)
|
||||
if !hasIDTok {
|
||||
return httperr.New(http.StatusBadRequest, "received response missing ID token")
|
||||
}
|
||||
validated, err := h.idTokenVerifier.Verify(r.Context(), idTok)
|
||||
token, err := h.validateToken(r.Context(), oauth2Tok, true)
|
||||
if err != nil {
|
||||
return httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||
}
|
||||
if validated.AccessTokenHash != "" {
|
||||
if err := validated.VerifyAccessToken(oauth2Tok.AccessToken); err != nil {
|
||||
return httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||
}
|
||||
}
|
||||
if err := h.nonce.Validate(validated); err != nil {
|
||||
return httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
|
||||
return err
|
||||
}
|
||||
|
||||
h.callbacks <- callbackResult{token: &Token{
|
||||
h.callbacks <- callbackResult{token: token}
|
||||
_, _ = w.Write([]byte("you have been logged in and may now close this tab"))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*Token, error) {
|
||||
idTok, hasIDTok := tok.Extra("id_token").(string)
|
||||
if !hasIDTok {
|
||||
return nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
|
||||
}
|
||||
validated, err := h.idTokenVerifier.Verify(ctx, idTok)
|
||||
if err != nil {
|
||||
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||
}
|
||||
if validated.AccessTokenHash != "" {
|
||||
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
|
||||
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
|
||||
}
|
||||
}
|
||||
if checkNonce {
|
||||
if err := h.nonce.Validate(validated); err != nil {
|
||||
return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
|
||||
}
|
||||
}
|
||||
return &Token{
|
||||
AccessToken: &AccessToken{
|
||||
Token: oauth2Tok.AccessToken,
|
||||
Type: oauth2Tok.TokenType,
|
||||
Expiry: metav1.NewTime(oauth2Tok.Expiry),
|
||||
Token: tok.AccessToken,
|
||||
Type: tok.TokenType,
|
||||
Expiry: metav1.NewTime(tok.Expiry),
|
||||
},
|
||||
RefreshToken: &RefreshToken{
|
||||
Token: oauth2Tok.RefreshToken,
|
||||
Token: tok.RefreshToken,
|
||||
},
|
||||
IDToken: &IDToken{
|
||||
Token: idTok,
|
||||
Expiry: metav1.NewTime(validated.Expiry),
|
||||
},
|
||||
}}
|
||||
_, _ = w.Write([]byte("you have been logged in and may now close this tab"))
|
||||
return nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *handlerState) serve(listener net.Listener) func() {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))
|
||||
srv := http.Server{
|
||||
Handler: securityheader.Wrap(mux),
|
||||
BaseContext: func(_ net.Listener) context.Context { return h.ctx },
|
||||
}
|
||||
go func() { _ = srv.Serve(listener) }()
|
||||
return func() {
|
||||
// Gracefully shut down the server, allowing up to 5 seconds for
|
||||
// clients to receive any in-flight responses.
|
||||
shutdownCtx, cancel := context.WithTimeout(h.ctx, 1*time.Second)
|
||||
_ = srv.Shutdown(shutdownCtx)
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
@ -49,7 +50,10 @@ func (m *mockSessionCache) PutToken(key SessionCacheKey, token *Token) {
|
||||
}
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
time1 := time.Date(3020, 10, 12, 13, 14, 15, 16, time.UTC)
|
||||
time1 := time.Date(2035, 10, 12, 13, 14, 15, 16, time.UTC)
|
||||
time1Unix := int64(2075807775)
|
||||
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
|
||||
|
||||
testToken := Token{
|
||||
AccessToken: &AccessToken{
|
||||
Token: "test-access-token",
|
||||
@ -59,7 +63,9 @@ func TestLogin(t *testing.T) {
|
||||
Token: "test-refresh-token",
|
||||
},
|
||||
IDToken: &IDToken{
|
||||
Token: "test-id-token",
|
||||
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above):
|
||||
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775
|
||||
Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA",
|
||||
Expiry: metav1.NewTime(time1.Add(2 * time.Minute)),
|
||||
},
|
||||
}
|
||||
@ -70,11 +76,15 @@ func TestLogin(t *testing.T) {
|
||||
}))
|
||||
t.Cleanup(errorServer.Close)
|
||||
|
||||
// Start a test server that returns a real keyset
|
||||
// Start a test server that returns a real keyset and answers refresh requests.
|
||||
providerMux := http.NewServeMux()
|
||||
successServer := httptest.NewServer(providerMux)
|
||||
t.Cleanup(successServer.Close)
|
||||
providerMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
w.Header().Set("content-type", "application/json")
|
||||
type providerJSON struct {
|
||||
Issuer string `json:"issuer"`
|
||||
@ -89,6 +99,44 @@ func TestLogin(t *testing.T) {
|
||||
JWKSURL: successServer.URL + "/keys",
|
||||
})
|
||||
})
|
||||
providerMux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if r.Form.Get("client_id") != "test-client-id" {
|
||||
http.Error(w, "expected client_id 'test-client-id'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if r.Form.Get("grant_type") != "refresh_token" {
|
||||
http.Error(w, "expected refresh_token grant type", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var response struct {
|
||||
oauth2.Token
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
}
|
||||
response.AccessToken = testToken.AccessToken.Token
|
||||
response.ExpiresIn = int64(time.Until(testToken.AccessToken.Expiry.Time).Seconds())
|
||||
response.RefreshToken = testToken.RefreshToken.Token
|
||||
response.IDToken = testToken.IDToken.Token
|
||||
|
||||
if r.Form.Get("refresh_token") == "test-refresh-token-returning-invalid-id-token" {
|
||||
response.IDToken = "not a valid JWT"
|
||||
} else if r.Form.Get("refresh_token") != "test-refresh-token" {
|
||||
http.Error(w, "expected refresh_token to be 'test-refresh-token'", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("content-type", "application/json")
|
||||
require.NoError(t, json.NewEncoder(w).Encode(&response))
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
@ -192,6 +240,106 @@ func TestLogin(t *testing.T) {
|
||||
issuer: errorServer.URL,
|
||||
wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL),
|
||||
},
|
||||
{
|
||||
name: "session cache hit with refreshable token",
|
||||
issuer: successServer.URL,
|
||||
clientID: "test-client-id",
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
|
||||
IDToken: &IDToken{
|
||||
Token: "expired-test-id-token",
|
||||
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
|
||||
},
|
||||
RefreshToken: &RefreshToken{Token: "test-refresh-token"},
|
||||
}}
|
||||
t.Cleanup(func() {
|
||||
cacheKey := SessionCacheKey{
|
||||
Issuer: successServer.URL,
|
||||
ClientID: "test-client-id",
|
||||
Scopes: []string{"test-scope"},
|
||||
RedirectURI: "http://localhost:0/callback",
|
||||
}
|
||||
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
|
||||
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
|
||||
require.Len(t, cache.sawPutTokens, 1)
|
||||
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
|
||||
})
|
||||
h.cache = cache
|
||||
|
||||
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
||||
provider, err := oidc.NewProvider(ctx, iss)
|
||||
require.NoError(t, err)
|
||||
return &mockDiscovery{provider: provider}, nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
},
|
||||
wantToken: &testToken,
|
||||
},
|
||||
{
|
||||
name: "session cache hit but refresh returns invalid token",
|
||||
issuer: successServer.URL,
|
||||
clientID: "test-client-id",
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
|
||||
IDToken: &IDToken{
|
||||
Token: "expired-test-id-token",
|
||||
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
|
||||
},
|
||||
RefreshToken: &RefreshToken{Token: "test-refresh-token-returning-invalid-id-token"},
|
||||
}}
|
||||
t.Cleanup(func() {
|
||||
require.Empty(t, cache.sawPutKeys)
|
||||
require.Empty(t, cache.sawPutTokens)
|
||||
})
|
||||
h.cache = cache
|
||||
|
||||
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
||||
provider, err := oidc.NewProvider(ctx, iss)
|
||||
require.NoError(t, err)
|
||||
return &mockDiscovery{provider: provider}, nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
},
|
||||
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
|
||||
},
|
||||
{
|
||||
name: "session cache hit but refresh fails",
|
||||
issuer: successServer.URL,
|
||||
clientID: "not-the-test-client-id",
|
||||
opt: func(t *testing.T) Option {
|
||||
return func(h *handlerState) error {
|
||||
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
|
||||
IDToken: &IDToken{
|
||||
Token: "expired-test-id-token",
|
||||
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
|
||||
},
|
||||
RefreshToken: &RefreshToken{Token: "test-refresh-token"},
|
||||
}}
|
||||
t.Cleanup(func() {
|
||||
require.Empty(t, cache.sawPutKeys)
|
||||
require.Empty(t, cache.sawPutTokens)
|
||||
})
|
||||
h.cache = cache
|
||||
|
||||
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
|
||||
provider, err := oidc.NewProvider(ctx, iss)
|
||||
require.NoError(t, err)
|
||||
return &mockDiscovery{provider: provider}, nil
|
||||
}
|
||||
|
||||
h.listenAddr = "invalid-listen-address"
|
||||
|
||||
return nil
|
||||
}
|
||||
},
|
||||
// Expect this to fall through to the authorization code flow, so it fails here.
|
||||
wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address",
|
||||
},
|
||||
{
|
||||
name: "listen failure",
|
||||
opt: func(t *testing.T) Option {
|
||||
@ -320,7 +468,30 @@ func TestLogin(t *testing.T) {
|
||||
require.Nil(t, tok)
|
||||
return
|
||||
}
|
||||
require.Equal(t, tt.wantToken, tok)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.wantToken == nil {
|
||||
require.Nil(t, tok)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, tok)
|
||||
|
||||
if want := tt.wantToken.AccessToken; want != nil {
|
||||
require.NotNil(t, tok.AccessToken)
|
||||
require.Equal(t, want.Token, tok.AccessToken.Token)
|
||||
require.Equal(t, want.Type, tok.AccessToken.Type)
|
||||
requireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second)
|
||||
} else {
|
||||
assert.Nil(t, tok.AccessToken)
|
||||
}
|
||||
require.Equal(t, tt.wantToken.RefreshToken, tok.RefreshToken)
|
||||
if want := tt.wantToken.IDToken; want != nil {
|
||||
require.NotNil(t, tok.IDToken)
|
||||
require.Equal(t, want.Token, tok.IDToken.Token)
|
||||
requireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second)
|
||||
} else {
|
||||
assert.Nil(t, tok.IDToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -504,3 +675,22 @@ func mockVerifier() *oidc.IDTokenVerifier {
|
||||
SkipClientIDCheck: true,
|
||||
})
|
||||
}
|
||||
|
||||
type mockDiscovery struct{ provider *oidc.Provider }
|
||||
|
||||
func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() }
|
||||
|
||||
func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() }
|
||||
|
||||
func requireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) {
|
||||
require.InDeltaf(t,
|
||||
float64(t1.UnixNano()),
|
||||
float64(t2.UnixNano()),
|
||||
float64(delta.Nanoseconds()),
|
||||
"expected %s and %s to be < %s apart, but they are %s apart",
|
||||
t1.Format(time.RFC3339Nano),
|
||||
t2.Format(time.RFC3339Nano),
|
||||
delta.String(),
|
||||
t1.Sub(t2).String(),
|
||||
)
|
||||
}
|
||||
|
@ -26,6 +26,8 @@ import (
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
clientauthenticationv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1"
|
||||
|
||||
"go.pinniped.dev/internal/oidcclient"
|
||||
"go.pinniped.dev/internal/oidcclient/filesession"
|
||||
"go.pinniped.dev/test/library"
|
||||
)
|
||||
|
||||
@ -176,18 +178,12 @@ func TestCLILoginOIDC(t *testing.T) {
|
||||
sessionCachePath := t.TempDir() + "/sessions.yaml"
|
||||
|
||||
// Start the CLI running the "alpha login oidc [...]" command with stdout/stderr connected to pipes.
|
||||
t.Logf("starting CLI subprocess")
|
||||
cmd := exec.CommandContext(ctx, pinnipedExe, "alpha", "login", "oidc",
|
||||
"--issuer", env.OIDCUpstream.Issuer,
|
||||
"--client-id", env.OIDCUpstream.ClientID,
|
||||
"--listen-port", strconv.Itoa(env.OIDCUpstream.LocalhostPort),
|
||||
"--session-cache", sessionCachePath,
|
||||
"--skip-browser",
|
||||
)
|
||||
cmd := oidcLoginCommand(ctx, t, pinnipedExe, sessionCachePath)
|
||||
stderr, err := cmd.StderrPipe()
|
||||
require.NoError(t, err)
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
require.NoError(t, err)
|
||||
t.Logf("starting CLI subprocess")
|
||||
require.NoError(t, cmd.Start())
|
||||
t.Cleanup(func() {
|
||||
err := cmd.Wait()
|
||||
@ -312,23 +308,44 @@ func TestCLILoginOIDC(t *testing.T) {
|
||||
|
||||
// Run the CLI again with the same session cache and login parameters.
|
||||
t.Logf("starting second CLI subprocess to test session caching")
|
||||
secondCtx, secondCancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer secondCancel()
|
||||
cmdOutput, err := exec.CommandContext(secondCtx, pinnipedExe, "alpha", "login", "oidc",
|
||||
"--issuer", env.OIDCUpstream.Issuer,
|
||||
"--client-id", env.OIDCUpstream.ClientID,
|
||||
"--listen-port", strconv.Itoa(env.OIDCUpstream.LocalhostPort),
|
||||
"--session-cache", sessionCachePath,
|
||||
"--skip-browser",
|
||||
).CombinedOutput()
|
||||
require.NoError(t, err)
|
||||
cmd2Output, err := oidcLoginCommand(ctx, t, pinnipedExe, sessionCachePath).CombinedOutput()
|
||||
require.NoError(t, err, string(cmd2Output))
|
||||
|
||||
// Expect the CLI to output the same ExecCredential in JSON format.
|
||||
t.Logf("validating second ExecCredential")
|
||||
var credOutput2 clientauthenticationv1beta1.ExecCredential
|
||||
require.NoErrorf(t, json.Unmarshal(cmdOutput, &credOutput2),
|
||||
"command returned something other than an ExecCredential:\n%s", string(cmdOutput))
|
||||
require.NoErrorf(t, json.Unmarshal(cmd2Output, &credOutput2),
|
||||
"command returned something other than an ExecCredential:\n%s", string(cmd2Output))
|
||||
require.Equal(t, credOutput, credOutput2)
|
||||
|
||||
// Overwrite the cache entry to remove the access and ID tokens.
|
||||
t.Logf("overwriting cache to remove valid ID token")
|
||||
cache := filesession.New(sessionCachePath)
|
||||
cacheKey := oidcclient.SessionCacheKey{
|
||||
Issuer: env.OIDCUpstream.Issuer,
|
||||
ClientID: env.OIDCUpstream.ClientID,
|
||||
Scopes: []string{"email", "offline_access", "openid", "profile"},
|
||||
RedirectURI: fmt.Sprintf("http://localhost:%d/callback", env.OIDCUpstream.LocalhostPort),
|
||||
}
|
||||
cached := cache.GetToken(cacheKey)
|
||||
require.NotNil(t, cached)
|
||||
require.NotNil(t, cached.RefreshToken)
|
||||
require.NotEmpty(t, cached.RefreshToken.Token)
|
||||
cached.IDToken = nil
|
||||
cached.AccessToken = nil
|
||||
cache.PutToken(cacheKey, cached)
|
||||
|
||||
// Run the CLI a third time with the same session cache and login parameters.
|
||||
t.Logf("starting third CLI subprocess to test refresh flow")
|
||||
cmd3Output, err := oidcLoginCommand(ctx, t, pinnipedExe, sessionCachePath).CombinedOutput()
|
||||
require.NoError(t, err, string(cmd2Output))
|
||||
|
||||
// Expect the CLI to output a new ExecCredential in JSON format (different from the one returned the first two times).
|
||||
t.Logf("validating third ExecCredential")
|
||||
var credOutput3 clientauthenticationv1beta1.ExecCredential
|
||||
require.NoErrorf(t, json.Unmarshal(cmd3Output, &credOutput3),
|
||||
"command returned something other than an ExecCredential:\n%s", string(cmd2Output))
|
||||
require.NotEqual(t, credOutput2.Status.Token, credOutput3.Status.Token)
|
||||
}
|
||||
|
||||
func waitForVisibleElements(t *testing.T, page *agouti.Page, selectors ...string) {
|
||||
@ -375,3 +392,14 @@ func spawnTestGoroutine(t *testing.T, f func() error) {
|
||||
})
|
||||
eg.Go(f)
|
||||
}
|
||||
|
||||
func oidcLoginCommand(ctx context.Context, t *testing.T, pinnipedExe string, sessionCachePath string) *exec.Cmd {
|
||||
env := library.IntegrationEnv(t)
|
||||
return exec.CommandContext(ctx, pinnipedExe, "alpha", "login", "oidc",
|
||||
"--issuer", env.OIDCUpstream.Issuer,
|
||||
"--client-id", env.OIDCUpstream.ClientID,
|
||||
"--listen-port", strconv.Itoa(env.OIDCUpstream.LocalhostPort),
|
||||
"--session-cache", sessionCachePath,
|
||||
"--skip-browser",
|
||||
)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user