Implement refresh flow in ./internal/oidcclient package.

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2020-10-22 16:12:02 -05:00
parent 8ae04605ca
commit 3508a28369
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
2 changed files with 304 additions and 58 deletions

View File

@ -30,6 +30,10 @@ const (
// This is non-zero to ensure that most of the time, your ID token won't expire in the middle of a multi-step k8s // This is non-zero to ensure that most of the time, your ID token won't expire in the middle of a multi-step k8s
// API operation. // API operation.
minIDTokenValidity = 10 * time.Minute minIDTokenValidity = 10 * time.Minute
// refreshTimeout is the amount of time allotted for OAuth2 refresh operations. Since these don't involve any
// user interaction, they should always be roughly as fast as network latency.
refreshTimeout = 30 * time.Second
) )
type handlerState struct { type handlerState struct {
@ -56,6 +60,7 @@ type handlerState struct {
generatePKCE func() (pkce.Code, error) generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error) generateNonce func() (nonce.Nonce, error)
openURL func(string) error openURL func(string) error
oidcDiscover func(context.Context, string) (discoveryI, error)
callbacks chan callbackResult callbacks chan callbackResult
} }
@ -123,6 +128,11 @@ type nopCache struct{}
func (*nopCache) GetToken(SessionCacheKey) *Token { return nil } func (*nopCache) GetToken(SessionCacheKey) *Token { return nil }
func (*nopCache) PutToken(SessionCacheKey, *Token) {} func (*nopCache) PutToken(SessionCacheKey, *Token) {}
type discoveryI interface {
Endpoint() oauth2.Endpoint
Verifier(*oidc.Config) *oidc.IDTokenVerifier
}
// Login performs an OAuth2/OIDC authorization code login using a localhost listener. // Login performs an OAuth2/OIDC authorization code login using a localhost listener.
func Login(issuer string, clientID string, opts ...Option) (*Token, error) { func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
h := handlerState{ h := handlerState{
@ -140,6 +150,9 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
generateNonce: nonce.Generate, generateNonce: nonce.Generate,
generatePKCE: pkce.Generate, generatePKCE: pkce.Generate,
openURL: browser.OpenURL, openURL: browser.OpenURL,
oidcDiscover: func(ctx context.Context, iss string) (discoveryI, error) {
return oidc.NewProvider(ctx, iss)
},
} }
for _, opt := range opts { for _, opt := range opts {
if err := opt(&h); err != nil { if err := opt(&h); err != nil {
@ -177,52 +190,52 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
} }
// If the ID token is still valid for a bit, return it immediately and skip the rest of the flow. // If the ID token is still valid for a bit, return it immediately and skip the rest of the flow.
if cached := h.cache.GetToken(cacheKey); cached != nil && cached := h.cache.GetToken(cacheKey)
cached.IDToken != nil && if cached != nil && cached.IDToken != nil && time.Until(cached.IDToken.Expiry.Time) > minIDTokenValidity {
time.Until(cached.IDToken.Expiry.Time) > minIDTokenValidity {
return cached, nil return cached, nil
} }
// Perform OIDC discovery. // Perform OIDC discovery.
provider, err := oidc.NewProvider(h.ctx, h.issuer) discovered, err := h.oidcDiscover(h.ctx, h.issuer)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err) return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
} }
h.idTokenVerifier = provider.Verifier(&oidc.Config{ClientID: h.clientID}) h.idTokenVerifier = discovered.Verifier(&oidc.Config{ClientID: h.clientID})
// Open a TCP listener.
listener, err := net.Listen("tcp", h.listenAddr)
if err != nil {
return nil, fmt.Errorf("could not open callback listener: %w", err)
}
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint. // Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
h.oauth2Config = &oauth2.Config{ h.oauth2Config = &oauth2.Config{
ClientID: h.clientID, ClientID: h.clientID,
Endpoint: provider.Endpoint(), Endpoint: discovered.Endpoint(),
RedirectURL: (&url.URL{
Scheme: "http",
Host: listener.Addr().String(),
Path: h.callbackPath,
}).String(),
Scopes: h.scopes, Scopes: h.scopes,
} }
// Start a callback server in a background goroutine. // If there was a cached refresh token, attempt to use the refresh flow instead of a fresh login.
mux := http.NewServeMux() if cached != nil && cached.RefreshToken != nil && cached.RefreshToken.Token != "" {
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback)) freshToken, err := h.handleRefresh(ctx, cached.RefreshToken)
srv := http.Server{ if err != nil {
Handler: securityheader.Wrap(mux), return nil, err
BaseContext: func(_ net.Listener) context.Context { return h.ctx },
} }
go func() { _ = srv.Serve(listener) }() // If we got a fresh token, we can update the cache and return it. Otherwise we fall through to the full refresh flow.
defer func() { if freshToken != nil {
// Gracefully shut down the server, allowing up to 5 seconds for h.cache.PutToken(cacheKey, freshToken)
// clients to receive any in-flight responses. return freshToken, nil
shutdownCtx, cancel := context.WithTimeout(h.ctx, 1*time.Second) }
_ = srv.Shutdown(shutdownCtx) }
cancel()
}() // Open a TCP listener and update the OAuth2 redirect_uri to match (in case we are using an ephemeral port number).
listener, err := net.Listen("tcp", h.listenAddr)
if err != nil {
return nil, fmt.Errorf("could not open callback listener: %w", err)
}
h.oauth2Config.RedirectURL = (&url.URL{
Scheme: "http",
Host: listener.Addr().String(),
Path: h.callbackPath,
}).String()
// Start a callback server in a background goroutine.
shutdown := h.serve(listener)
defer shutdown()
// Open the authorize URL in the users browser. // Open the authorize URL in the users browser.
authorizeURL := h.oauth2Config.AuthCodeURL( authorizeURL := h.oauth2Config.AuthCodeURL(
@ -249,6 +262,22 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
} }
} }
func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *RefreshToken) (*Token, error) {
ctx, cancel := context.WithTimeout(ctx, refreshTimeout)
defer cancel()
refreshSource := h.oauth2Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken.Token})
refreshed, err := refreshSource.Token()
if err != nil {
// Ignore errors during refresh, but return nil which will trigger the full login flow.
return nil, nil
}
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
// some providers do not include one, so we skip the nonce validation here (but not other validations).
return h.validateToken(ctx, refreshed, false)
}
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) { func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
// If we return an error, also report it back over the channel to the main CLI thread. // If we return an error, also report it back over the channel to the main CLI thread.
defer func() { defer func() {
@ -280,37 +309,64 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
} }
// Perform required validations on the returned ID token. // Perform required validations on the returned ID token.
idTok, hasIDTok := oauth2Tok.Extra("id_token").(string) token, err := h.validateToken(r.Context(), oauth2Tok, true)
if !hasIDTok {
return httperr.New(http.StatusBadRequest, "received response missing ID token")
}
validated, err := h.idTokenVerifier.Verify(r.Context(), idTok)
if err != nil { if err != nil {
return httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err) return err
}
if validated.AccessTokenHash != "" {
if err := validated.VerifyAccessToken(oauth2Tok.AccessToken); err != nil {
return httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
}
if err := h.nonce.Validate(validated); err != nil {
return httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
} }
h.callbacks <- callbackResult{token: &Token{ h.callbacks <- callbackResult{token: token}
_, _ = w.Write([]byte("you have been logged in and may now close this tab"))
return nil
}
func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*Token, error) {
idTok, hasIDTok := tok.Extra("id_token").(string)
if !hasIDTok {
return nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
}
validated, err := h.idTokenVerifier.Verify(ctx, idTok)
if err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
if validated.AccessTokenHash != "" {
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
}
if checkNonce {
if err := h.nonce.Validate(validated); err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
}
}
return &Token{
AccessToken: &AccessToken{ AccessToken: &AccessToken{
Token: oauth2Tok.AccessToken, Token: tok.AccessToken,
Type: oauth2Tok.TokenType, Type: tok.TokenType,
Expiry: metav1.NewTime(oauth2Tok.Expiry), Expiry: metav1.NewTime(tok.Expiry),
}, },
RefreshToken: &RefreshToken{ RefreshToken: &RefreshToken{
Token: oauth2Tok.RefreshToken, Token: tok.RefreshToken,
}, },
IDToken: &IDToken{ IDToken: &IDToken{
Token: idTok, Token: idTok,
Expiry: metav1.NewTime(validated.Expiry), Expiry: metav1.NewTime(validated.Expiry),
}, },
}} }, nil
_, _ = w.Write([]byte("you have been logged in and may now close this tab")) }
return nil
func (h *handlerState) serve(listener net.Listener) func() {
mux := http.NewServeMux()
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))
srv := http.Server{
Handler: securityheader.Wrap(mux),
BaseContext: func(_ net.Listener) context.Context { return h.ctx },
}
go func() { _ = srv.Serve(listener) }()
return func() {
// Gracefully shut down the server, allowing up to 5 seconds for
// clients to receive any in-flight responses.
shutdownCtx, cancel := context.WithTimeout(h.ctx, 1*time.Second)
_ = srv.Shutdown(shutdownCtx)
cancel()
}
} }

View File

@ -15,6 +15,7 @@ import (
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
@ -49,7 +50,10 @@ func (m *mockSessionCache) PutToken(key SessionCacheKey, token *Token) {
} }
func TestLogin(t *testing.T) { func TestLogin(t *testing.T) {
time1 := time.Date(3020, 10, 12, 13, 14, 15, 16, time.UTC) time1 := time.Date(2035, 10, 12, 13, 14, 15, 16, time.UTC)
time1Unix := int64(2075807775)
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
testToken := Token{ testToken := Token{
AccessToken: &AccessToken{ AccessToken: &AccessToken{
Token: "test-access-token", Token: "test-access-token",
@ -59,7 +63,9 @@ func TestLogin(t *testing.T) {
Token: "test-refresh-token", Token: "test-refresh-token",
}, },
IDToken: &IDToken{ IDToken: &IDToken{
Token: "test-id-token", // Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above):
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775
Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA",
Expiry: metav1.NewTime(time1.Add(2 * time.Minute)), Expiry: metav1.NewTime(time1.Add(2 * time.Minute)),
}, },
} }
@ -70,11 +76,15 @@ func TestLogin(t *testing.T) {
})) }))
t.Cleanup(errorServer.Close) t.Cleanup(errorServer.Close)
// Start a test server that returns a real keyset // Start a test server that returns a real keyset and answers refresh requests.
providerMux := http.NewServeMux() providerMux := http.NewServeMux()
successServer := httptest.NewServer(providerMux) successServer := httptest.NewServer(providerMux)
t.Cleanup(successServer.Close) t.Cleanup(successServer.Close)
providerMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { providerMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
return
}
w.Header().Set("content-type", "application/json") w.Header().Set("content-type", "application/json")
type providerJSON struct { type providerJSON struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
@ -89,6 +99,44 @@ func TestLogin(t *testing.T) {
JWKSURL: successServer.URL + "/keys", JWKSURL: successServer.URL + "/keys",
}) })
}) })
providerMux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
return
}
if err := r.ParseForm(); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if r.Form.Get("client_id") != "test-client-id" {
http.Error(w, "expected client_id 'test-client-id'", http.StatusBadRequest)
return
}
if r.Form.Get("grant_type") != "refresh_token" {
http.Error(w, "expected refresh_token grant type", http.StatusBadRequest)
return
}
var response struct {
oauth2.Token
IDToken string `json:"id_token,omitempty"`
ExpiresIn int64 `json:"expires_in"`
}
response.AccessToken = testToken.AccessToken.Token
response.ExpiresIn = int64(time.Until(testToken.AccessToken.Expiry.Time).Seconds())
response.RefreshToken = testToken.RefreshToken.Token
response.IDToken = testToken.IDToken.Token
if r.Form.Get("refresh_token") == "test-refresh-token-returning-invalid-id-token" {
response.IDToken = "not a valid JWT"
} else if r.Form.Get("refresh_token") != "test-refresh-token" {
http.Error(w, "expected refresh_token to be 'test-refresh-token'", http.StatusBadRequest)
return
}
w.Header().Set("content-type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(&response))
})
tests := []struct { tests := []struct {
name string name string
@ -192,6 +240,106 @@ func TestLogin(t *testing.T) {
issuer: errorServer.URL, issuer: errorServer.URL,
wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL), wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL),
}, },
{
name: "session cache hit with refreshable token",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
IDToken: &IDToken{
Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
},
RefreshToken: &RefreshToken{Token: "test-refresh-token"},
}}
t.Cleanup(func() {
cacheKey := SessionCacheKey{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Len(t, cache.sawPutTokens, 1)
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
})
h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
return nil
}
},
wantToken: &testToken,
},
{
name: "session cache hit but refresh returns invalid token",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
IDToken: &IDToken{
Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
},
RefreshToken: &RefreshToken{Token: "test-refresh-token-returning-invalid-id-token"},
}}
t.Cleanup(func() {
require.Empty(t, cache.sawPutKeys)
require.Empty(t, cache.sawPutTokens)
})
h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
return nil
}
},
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
},
{
name: "session cache hit but refresh fails",
issuer: successServer.URL,
clientID: "not-the-test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
IDToken: &IDToken{
Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
},
RefreshToken: &RefreshToken{Token: "test-refresh-token"},
}}
t.Cleanup(func() {
require.Empty(t, cache.sawPutKeys)
require.Empty(t, cache.sawPutTokens)
})
h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
h.listenAddr = "invalid-listen-address"
return nil
}
},
// Expect this to fall through to the authorization code flow, so it fails here.
wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address",
},
{ {
name: "listen failure", name: "listen failure",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
@ -320,7 +468,30 @@ func TestLogin(t *testing.T) {
require.Nil(t, tok) require.Nil(t, tok)
return return
} }
require.Equal(t, tt.wantToken, tok) require.NoError(t, err)
if tt.wantToken == nil {
require.Nil(t, tok)
return
}
require.NotNil(t, tok)
if want := tt.wantToken.AccessToken; want != nil {
require.NotNil(t, tok.AccessToken)
require.Equal(t, want.Token, tok.AccessToken.Token)
require.Equal(t, want.Type, tok.AccessToken.Type)
requireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second)
} else {
assert.Nil(t, tok.AccessToken)
}
require.Equal(t, tt.wantToken.RefreshToken, tok.RefreshToken)
if want := tt.wantToken.IDToken; want != nil {
require.NotNil(t, tok.IDToken)
require.Equal(t, want.Token, tok.IDToken.Token)
requireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second)
} else {
assert.Nil(t, tok.IDToken)
}
}) })
} }
} }
@ -504,3 +675,22 @@ func mockVerifier() *oidc.IDTokenVerifier {
SkipClientIDCheck: true, SkipClientIDCheck: true,
}) })
} }
type mockDiscovery struct{ provider *oidc.Provider }
func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() }
func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() }
func requireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) {
require.InDeltaf(t,
float64(t1.UnixNano()),
float64(t2.UnixNano()),
float64(delta.Nanoseconds()),
"expected %s and %s to be < %s apart, but they are %s apart",
t1.Format(time.RFC3339Nano),
t2.Format(time.RFC3339Nano),
delta.String(),
t1.Sub(t2).String(),
)
}