diff --git a/internal/oidcclient/login.go b/internal/oidcclient/login.go index 90e8c1a9..5ab3b5ff 100644 --- a/internal/oidcclient/login.go +++ b/internal/oidcclient/login.go @@ -30,6 +30,10 @@ const ( // This is non-zero to ensure that most of the time, your ID token won't expire in the middle of a multi-step k8s // API operation. minIDTokenValidity = 10 * time.Minute + + // refreshTimeout is the amount of time allotted for OAuth2 refresh operations. Since these don't involve any + // user interaction, they should always be roughly as fast as network latency. + refreshTimeout = 30 * time.Second ) type handlerState struct { @@ -56,6 +60,7 @@ type handlerState struct { generatePKCE func() (pkce.Code, error) generateNonce func() (nonce.Nonce, error) openURL func(string) error + oidcDiscover func(context.Context, string) (discoveryI, error) callbacks chan callbackResult } @@ -123,6 +128,11 @@ type nopCache struct{} func (*nopCache) GetToken(SessionCacheKey) *Token { return nil } func (*nopCache) PutToken(SessionCacheKey, *Token) {} +type discoveryI interface { + Endpoint() oauth2.Endpoint + Verifier(*oidc.Config) *oidc.IDTokenVerifier +} + // Login performs an OAuth2/OIDC authorization code login using a localhost listener. func Login(issuer string, clientID string, opts ...Option) (*Token, error) { h := handlerState{ @@ -140,6 +150,9 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) { generateNonce: nonce.Generate, generatePKCE: pkce.Generate, openURL: browser.OpenURL, + oidcDiscover: func(ctx context.Context, iss string) (discoveryI, error) { + return oidc.NewProvider(ctx, iss) + }, } for _, opt := range opts { if err := opt(&h); err != nil { @@ -177,52 +190,52 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) { } // If the ID token is still valid for a bit, return it immediately and skip the rest of the flow. - if cached := h.cache.GetToken(cacheKey); cached != nil && - cached.IDToken != nil && - time.Until(cached.IDToken.Expiry.Time) > minIDTokenValidity { + cached := h.cache.GetToken(cacheKey) + if cached != nil && cached.IDToken != nil && time.Until(cached.IDToken.Expiry.Time) > minIDTokenValidity { return cached, nil } // Perform OIDC discovery. - provider, err := oidc.NewProvider(h.ctx, h.issuer) + discovered, err := h.oidcDiscover(h.ctx, h.issuer) if err != nil { return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err) } - h.idTokenVerifier = provider.Verifier(&oidc.Config{ClientID: h.clientID}) - - // Open a TCP listener. - listener, err := net.Listen("tcp", h.listenAddr) - if err != nil { - return nil, fmt.Errorf("could not open callback listener: %w", err) - } + h.idTokenVerifier = discovered.Verifier(&oidc.Config{ClientID: h.clientID}) // Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint. h.oauth2Config = &oauth2.Config{ ClientID: h.clientID, - Endpoint: provider.Endpoint(), - RedirectURL: (&url.URL{ - Scheme: "http", - Host: listener.Addr().String(), - Path: h.callbackPath, - }).String(), - Scopes: h.scopes, + Endpoint: discovered.Endpoint(), + Scopes: h.scopes, } - // Start a callback server in a background goroutine. - mux := http.NewServeMux() - mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback)) - srv := http.Server{ - Handler: securityheader.Wrap(mux), - BaseContext: func(_ net.Listener) context.Context { return h.ctx }, + // If there was a cached refresh token, attempt to use the refresh flow instead of a fresh login. + if cached != nil && cached.RefreshToken != nil && cached.RefreshToken.Token != "" { + freshToken, err := h.handleRefresh(ctx, cached.RefreshToken) + if err != nil { + return nil, err + } + // If we got a fresh token, we can update the cache and return it. Otherwise we fall through to the full refresh flow. + if freshToken != nil { + h.cache.PutToken(cacheKey, freshToken) + return freshToken, nil + } } - go func() { _ = srv.Serve(listener) }() - defer func() { - // Gracefully shut down the server, allowing up to 5 seconds for - // clients to receive any in-flight responses. - shutdownCtx, cancel := context.WithTimeout(h.ctx, 1*time.Second) - _ = srv.Shutdown(shutdownCtx) - cancel() - }() + + // Open a TCP listener and update the OAuth2 redirect_uri to match (in case we are using an ephemeral port number). + listener, err := net.Listen("tcp", h.listenAddr) + if err != nil { + return nil, fmt.Errorf("could not open callback listener: %w", err) + } + h.oauth2Config.RedirectURL = (&url.URL{ + Scheme: "http", + Host: listener.Addr().String(), + Path: h.callbackPath, + }).String() + + // Start a callback server in a background goroutine. + shutdown := h.serve(listener) + defer shutdown() // Open the authorize URL in the users browser. authorizeURL := h.oauth2Config.AuthCodeURL( @@ -249,6 +262,22 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) { } } +func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *RefreshToken) (*Token, error) { + ctx, cancel := context.WithTimeout(ctx, refreshTimeout) + defer cancel() + refreshSource := h.oauth2Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken.Token}) + + refreshed, err := refreshSource.Token() + if err != nil { + // Ignore errors during refresh, but return nil which will trigger the full login flow. + return nil, nil + } + + // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least + // some providers do not include one, so we skip the nonce validation here (but not other validations). + return h.validateToken(ctx, refreshed, false) +} + func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { // If we return an error, also report it back over the channel to the main CLI thread. defer func() { @@ -280,37 +309,64 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req } // Perform required validations on the returned ID token. - idTok, hasIDTok := oauth2Tok.Extra("id_token").(string) - if !hasIDTok { - return httperr.New(http.StatusBadRequest, "received response missing ID token") - } - validated, err := h.idTokenVerifier.Verify(r.Context(), idTok) + token, err := h.validateToken(r.Context(), oauth2Tok, true) if err != nil { - return httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) - } - if validated.AccessTokenHash != "" { - if err := validated.VerifyAccessToken(oauth2Tok.AccessToken); err != nil { - return httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) - } - } - if err := h.nonce.Validate(validated); err != nil { - return httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err) + return err } - h.callbacks <- callbackResult{token: &Token{ + h.callbacks <- callbackResult{token: token} + _, _ = w.Write([]byte("you have been logged in and may now close this tab")) + return nil +} + +func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*Token, error) { + idTok, hasIDTok := tok.Extra("id_token").(string) + if !hasIDTok { + return nil, httperr.New(http.StatusBadRequest, "received response missing ID token") + } + validated, err := h.idTokenVerifier.Verify(ctx, idTok) + if err != nil { + return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) + } + if validated.AccessTokenHash != "" { + if err := validated.VerifyAccessToken(tok.AccessToken); err != nil { + return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) + } + } + if checkNonce { + if err := h.nonce.Validate(validated); err != nil { + return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err) + } + } + return &Token{ AccessToken: &AccessToken{ - Token: oauth2Tok.AccessToken, - Type: oauth2Tok.TokenType, - Expiry: metav1.NewTime(oauth2Tok.Expiry), + Token: tok.AccessToken, + Type: tok.TokenType, + Expiry: metav1.NewTime(tok.Expiry), }, RefreshToken: &RefreshToken{ - Token: oauth2Tok.RefreshToken, + Token: tok.RefreshToken, }, IDToken: &IDToken{ Token: idTok, Expiry: metav1.NewTime(validated.Expiry), }, - }} - _, _ = w.Write([]byte("you have been logged in and may now close this tab")) - return nil + }, nil +} + +func (h *handlerState) serve(listener net.Listener) func() { + mux := http.NewServeMux() + mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback)) + srv := http.Server{ + Handler: securityheader.Wrap(mux), + BaseContext: func(_ net.Listener) context.Context { return h.ctx }, + } + go func() { _ = srv.Serve(listener) }() + return func() { + // Gracefully shut down the server, allowing up to 5 seconds for + // clients to receive any in-flight responses. + shutdownCtx, cancel := context.WithTimeout(h.ctx, 1*time.Second) + _ = srv.Shutdown(shutdownCtx) + cancel() + } } diff --git a/internal/oidcclient/login_test.go b/internal/oidcclient/login_test.go index 1271b98d..ea12737f 100644 --- a/internal/oidcclient/login_test.go +++ b/internal/oidcclient/login_test.go @@ -15,6 +15,7 @@ import ( "github.com/coreos/go-oidc" "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/oauth2" "gopkg.in/square/go-jose.v2" @@ -49,7 +50,10 @@ func (m *mockSessionCache) PutToken(key SessionCacheKey, token *Token) { } func TestLogin(t *testing.T) { - time1 := time.Date(3020, 10, 12, 13, 14, 15, 16, time.UTC) + time1 := time.Date(2035, 10, 12, 13, 14, 15, 16, time.UTC) + time1Unix := int64(2075807775) + require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix()) + testToken := Token{ AccessToken: &AccessToken{ Token: "test-access-token", @@ -59,7 +63,9 @@ func TestLogin(t *testing.T) { Token: "test-refresh-token", }, IDToken: &IDToken{ - Token: "test-id-token", + // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above): + // step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775 + Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA", Expiry: metav1.NewTime(time1.Add(2 * time.Minute)), }, } @@ -70,11 +76,15 @@ func TestLogin(t *testing.T) { })) t.Cleanup(errorServer.Close) - // Start a test server that returns a real keyset + // Start a test server that returns a real keyset and answers refresh requests. providerMux := http.NewServeMux() successServer := httptest.NewServer(providerMux) t.Cleanup(successServer.Close) providerMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "unexpected method", http.StatusMethodNotAllowed) + return + } w.Header().Set("content-type", "application/json") type providerJSON struct { Issuer string `json:"issuer"` @@ -89,6 +99,44 @@ func TestLogin(t *testing.T) { JWKSURL: successServer.URL + "/keys", }) }) + providerMux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "unexpected method", http.StatusMethodNotAllowed) + return + } + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if r.Form.Get("client_id") != "test-client-id" { + http.Error(w, "expected client_id 'test-client-id'", http.StatusBadRequest) + return + } + if r.Form.Get("grant_type") != "refresh_token" { + http.Error(w, "expected refresh_token grant type", http.StatusBadRequest) + return + } + + var response struct { + oauth2.Token + IDToken string `json:"id_token,omitempty"` + ExpiresIn int64 `json:"expires_in"` + } + response.AccessToken = testToken.AccessToken.Token + response.ExpiresIn = int64(time.Until(testToken.AccessToken.Expiry.Time).Seconds()) + response.RefreshToken = testToken.RefreshToken.Token + response.IDToken = testToken.IDToken.Token + + if r.Form.Get("refresh_token") == "test-refresh-token-returning-invalid-id-token" { + response.IDToken = "not a valid JWT" + } else if r.Form.Get("refresh_token") != "test-refresh-token" { + http.Error(w, "expected refresh_token to be 'test-refresh-token'", http.StatusBadRequest) + return + } + + w.Header().Set("content-type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(&response)) + }) tests := []struct { name string @@ -192,6 +240,106 @@ func TestLogin(t *testing.T) { issuer: errorServer.URL, wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL), }, + { + name: "session cache hit with refreshable token", + issuer: successServer.URL, + clientID: "test-client-id", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + cache := &mockSessionCache{t: t, getReturnsToken: &Token{ + IDToken: &IDToken{ + Token: "expired-test-id-token", + Expiry: metav1.Now(), // less than Now() + minIDTokenValidity + }, + RefreshToken: &RefreshToken{Token: "test-refresh-token"}, + }} + t.Cleanup(func() { + cacheKey := SessionCacheKey{ + Issuer: successServer.URL, + ClientID: "test-client-id", + Scopes: []string{"test-scope"}, + RedirectURI: "http://localhost:0/callback", + } + require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys) + require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys) + require.Len(t, cache.sawPutTokens, 1) + require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token) + }) + h.cache = cache + + h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) { + provider, err := oidc.NewProvider(ctx, iss) + require.NoError(t, err) + return &mockDiscovery{provider: provider}, nil + } + return nil + } + }, + wantToken: &testToken, + }, + { + name: "session cache hit but refresh returns invalid token", + issuer: successServer.URL, + clientID: "test-client-id", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + cache := &mockSessionCache{t: t, getReturnsToken: &Token{ + IDToken: &IDToken{ + Token: "expired-test-id-token", + Expiry: metav1.Now(), // less than Now() + minIDTokenValidity + }, + RefreshToken: &RefreshToken{Token: "test-refresh-token-returning-invalid-id-token"}, + }} + t.Cleanup(func() { + require.Empty(t, cache.sawPutKeys) + require.Empty(t, cache.sawPutTokens) + }) + h.cache = cache + + h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) { + provider, err := oidc.NewProvider(ctx, iss) + require.NoError(t, err) + return &mockDiscovery{provider: provider}, nil + } + + return nil + } + }, + wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts", + }, + { + name: "session cache hit but refresh fails", + issuer: successServer.URL, + clientID: "not-the-test-client-id", + opt: func(t *testing.T) Option { + return func(h *handlerState) error { + cache := &mockSessionCache{t: t, getReturnsToken: &Token{ + IDToken: &IDToken{ + Token: "expired-test-id-token", + Expiry: metav1.Now(), // less than Now() + minIDTokenValidity + }, + RefreshToken: &RefreshToken{Token: "test-refresh-token"}, + }} + t.Cleanup(func() { + require.Empty(t, cache.sawPutKeys) + require.Empty(t, cache.sawPutTokens) + }) + h.cache = cache + + h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) { + provider, err := oidc.NewProvider(ctx, iss) + require.NoError(t, err) + return &mockDiscovery{provider: provider}, nil + } + + h.listenAddr = "invalid-listen-address" + + return nil + } + }, + // Expect this to fall through to the authorization code flow, so it fails here. + wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address", + }, { name: "listen failure", opt: func(t *testing.T) Option { @@ -320,7 +468,30 @@ func TestLogin(t *testing.T) { require.Nil(t, tok) return } - require.Equal(t, tt.wantToken, tok) + require.NoError(t, err) + + if tt.wantToken == nil { + require.Nil(t, tok) + return + } + require.NotNil(t, tok) + + if want := tt.wantToken.AccessToken; want != nil { + require.NotNil(t, tok.AccessToken) + require.Equal(t, want.Token, tok.AccessToken.Token) + require.Equal(t, want.Type, tok.AccessToken.Type) + requireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second) + } else { + assert.Nil(t, tok.AccessToken) + } + require.Equal(t, tt.wantToken.RefreshToken, tok.RefreshToken) + if want := tt.wantToken.IDToken; want != nil { + require.NotNil(t, tok.IDToken) + require.Equal(t, want.Token, tok.IDToken.Token) + requireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second) + } else { + assert.Nil(t, tok.IDToken) + } }) } } @@ -504,3 +675,22 @@ func mockVerifier() *oidc.IDTokenVerifier { SkipClientIDCheck: true, }) } + +type mockDiscovery struct{ provider *oidc.Provider } + +func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() } + +func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() } + +func requireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) { + require.InDeltaf(t, + float64(t1.UnixNano()), + float64(t2.UnixNano()), + float64(delta.Nanoseconds()), + "expected %s and %s to be < %s apart, but they are %s apart", + t1.Format(time.RFC3339Nano), + t2.Format(time.RFC3339Nano), + delta.String(), + t1.Sub(t2).String(), + ) +}