Implement refresh flow in ./internal/oidcclient package.

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2020-10-22 16:12:02 -05:00
parent 8ae04605ca
commit 3508a28369
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
2 changed files with 304 additions and 58 deletions

View File

@ -30,6 +30,10 @@ const (
// This is non-zero to ensure that most of the time, your ID token won't expire in the middle of a multi-step k8s
// API operation.
minIDTokenValidity = 10 * time.Minute
// refreshTimeout is the amount of time allotted for OAuth2 refresh operations. Since these don't involve any
// user interaction, they should always be roughly as fast as network latency.
refreshTimeout = 30 * time.Second
)
type handlerState struct {
@ -56,6 +60,7 @@ type handlerState struct {
generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error)
openURL func(string) error
oidcDiscover func(context.Context, string) (discoveryI, error)
callbacks chan callbackResult
}
@ -123,6 +128,11 @@ type nopCache struct{}
func (*nopCache) GetToken(SessionCacheKey) *Token { return nil }
func (*nopCache) PutToken(SessionCacheKey, *Token) {}
type discoveryI interface {
Endpoint() oauth2.Endpoint
Verifier(*oidc.Config) *oidc.IDTokenVerifier
}
// Login performs an OAuth2/OIDC authorization code login using a localhost listener.
func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
h := handlerState{
@ -140,6 +150,9 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
generateNonce: nonce.Generate,
generatePKCE: pkce.Generate,
openURL: browser.OpenURL,
oidcDiscover: func(ctx context.Context, iss string) (discoveryI, error) {
return oidc.NewProvider(ctx, iss)
},
}
for _, opt := range opts {
if err := opt(&h); err != nil {
@ -177,52 +190,52 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
}
// If the ID token is still valid for a bit, return it immediately and skip the rest of the flow.
if cached := h.cache.GetToken(cacheKey); cached != nil &&
cached.IDToken != nil &&
time.Until(cached.IDToken.Expiry.Time) > minIDTokenValidity {
cached := h.cache.GetToken(cacheKey)
if cached != nil && cached.IDToken != nil && time.Until(cached.IDToken.Expiry.Time) > minIDTokenValidity {
return cached, nil
}
// Perform OIDC discovery.
provider, err := oidc.NewProvider(h.ctx, h.issuer)
discovered, err := h.oidcDiscover(h.ctx, h.issuer)
if err != nil {
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
}
h.idTokenVerifier = provider.Verifier(&oidc.Config{ClientID: h.clientID})
// Open a TCP listener.
listener, err := net.Listen("tcp", h.listenAddr)
if err != nil {
return nil, fmt.Errorf("could not open callback listener: %w", err)
}
h.idTokenVerifier = discovered.Verifier(&oidc.Config{ClientID: h.clientID})
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
h.oauth2Config = &oauth2.Config{
ClientID: h.clientID,
Endpoint: provider.Endpoint(),
RedirectURL: (&url.URL{
Scheme: "http",
Host: listener.Addr().String(),
Path: h.callbackPath,
}).String(),
Scopes: h.scopes,
Endpoint: discovered.Endpoint(),
Scopes: h.scopes,
}
// Start a callback server in a background goroutine.
mux := http.NewServeMux()
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))
srv := http.Server{
Handler: securityheader.Wrap(mux),
BaseContext: func(_ net.Listener) context.Context { return h.ctx },
// If there was a cached refresh token, attempt to use the refresh flow instead of a fresh login.
if cached != nil && cached.RefreshToken != nil && cached.RefreshToken.Token != "" {
freshToken, err := h.handleRefresh(ctx, cached.RefreshToken)
if err != nil {
return nil, err
}
// If we got a fresh token, we can update the cache and return it. Otherwise we fall through to the full refresh flow.
if freshToken != nil {
h.cache.PutToken(cacheKey, freshToken)
return freshToken, nil
}
}
go func() { _ = srv.Serve(listener) }()
defer func() {
// Gracefully shut down the server, allowing up to 5 seconds for
// clients to receive any in-flight responses.
shutdownCtx, cancel := context.WithTimeout(h.ctx, 1*time.Second)
_ = srv.Shutdown(shutdownCtx)
cancel()
}()
// Open a TCP listener and update the OAuth2 redirect_uri to match (in case we are using an ephemeral port number).
listener, err := net.Listen("tcp", h.listenAddr)
if err != nil {
return nil, fmt.Errorf("could not open callback listener: %w", err)
}
h.oauth2Config.RedirectURL = (&url.URL{
Scheme: "http",
Host: listener.Addr().String(),
Path: h.callbackPath,
}).String()
// Start a callback server in a background goroutine.
shutdown := h.serve(listener)
defer shutdown()
// Open the authorize URL in the users browser.
authorizeURL := h.oauth2Config.AuthCodeURL(
@ -249,6 +262,22 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
}
}
func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *RefreshToken) (*Token, error) {
ctx, cancel := context.WithTimeout(ctx, refreshTimeout)
defer cancel()
refreshSource := h.oauth2Config.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken.Token})
refreshed, err := refreshSource.Token()
if err != nil {
// Ignore errors during refresh, but return nil which will trigger the full login flow.
return nil, nil
}
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at least
// some providers do not include one, so we skip the nonce validation here (but not other validations).
return h.validateToken(ctx, refreshed, false)
}
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
// If we return an error, also report it back over the channel to the main CLI thread.
defer func() {
@ -280,37 +309,64 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
}
// Perform required validations on the returned ID token.
idTok, hasIDTok := oauth2Tok.Extra("id_token").(string)
if !hasIDTok {
return httperr.New(http.StatusBadRequest, "received response missing ID token")
}
validated, err := h.idTokenVerifier.Verify(r.Context(), idTok)
token, err := h.validateToken(r.Context(), oauth2Tok, true)
if err != nil {
return httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
if validated.AccessTokenHash != "" {
if err := validated.VerifyAccessToken(oauth2Tok.AccessToken); err != nil {
return httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
}
if err := h.nonce.Validate(validated); err != nil {
return httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
return err
}
h.callbacks <- callbackResult{token: &Token{
h.callbacks <- callbackResult{token: token}
_, _ = w.Write([]byte("you have been logged in and may now close this tab"))
return nil
}
func (h *handlerState) validateToken(ctx context.Context, tok *oauth2.Token, checkNonce bool) (*Token, error) {
idTok, hasIDTok := tok.Extra("id_token").(string)
if !hasIDTok {
return nil, httperr.New(http.StatusBadRequest, "received response missing ID token")
}
validated, err := h.idTokenVerifier.Verify(ctx, idTok)
if err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
if validated.AccessTokenHash != "" {
if err := validated.VerifyAccessToken(tok.AccessToken); err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
}
if checkNonce {
if err := h.nonce.Validate(validated); err != nil {
return nil, httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
}
}
return &Token{
AccessToken: &AccessToken{
Token: oauth2Tok.AccessToken,
Type: oauth2Tok.TokenType,
Expiry: metav1.NewTime(oauth2Tok.Expiry),
Token: tok.AccessToken,
Type: tok.TokenType,
Expiry: metav1.NewTime(tok.Expiry),
},
RefreshToken: &RefreshToken{
Token: oauth2Tok.RefreshToken,
Token: tok.RefreshToken,
},
IDToken: &IDToken{
Token: idTok,
Expiry: metav1.NewTime(validated.Expiry),
},
}}
_, _ = w.Write([]byte("you have been logged in and may now close this tab"))
return nil
}, nil
}
func (h *handlerState) serve(listener net.Listener) func() {
mux := http.NewServeMux()
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))
srv := http.Server{
Handler: securityheader.Wrap(mux),
BaseContext: func(_ net.Listener) context.Context { return h.ctx },
}
go func() { _ = srv.Serve(listener) }()
return func() {
// Gracefully shut down the server, allowing up to 5 seconds for
// clients to receive any in-flight responses.
shutdownCtx, cancel := context.WithTimeout(h.ctx, 1*time.Second)
_ = srv.Shutdown(shutdownCtx)
cancel()
}
}

View File

@ -15,6 +15,7 @@ import (
"github.com/coreos/go-oidc"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
@ -49,7 +50,10 @@ func (m *mockSessionCache) PutToken(key SessionCacheKey, token *Token) {
}
func TestLogin(t *testing.T) {
time1 := time.Date(3020, 10, 12, 13, 14, 15, 16, time.UTC)
time1 := time.Date(2035, 10, 12, 13, 14, 15, 16, time.UTC)
time1Unix := int64(2075807775)
require.Equal(t, time1Unix, time1.Add(2*time.Minute).Unix())
testToken := Token{
AccessToken: &AccessToken{
Token: "test-access-token",
@ -59,7 +63,9 @@ func TestLogin(t *testing.T) {
Token: "test-refresh-token",
},
IDToken: &IDToken{
Token: "test-id-token",
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/ (using time1Unix from above):
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti" --exp 2075807775
Token: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImV4cCI6MjA3NTgwNzc3NSwiaWF0IjoxNjAzMzk5NTY4LCJpc3MiOiJ0ZXN0LWlzc3VlciIsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAzMzk5NTY4LCJzdWIiOiJ0ZXN0LXVzZXIifQ.CdwUWQb6xELeFlC4u84K4rzks7YiDJiXxIo_SaRvCHBijxtil812RBRfPuAyYKJlGwFx1g-JYvkUg69X5NmvmLXkaOdHIKUAT7Nqa7yqd1xOAP9IlFj9qZM3Q7s8gWWW9da-_ryagzN4fyGfNfYeGhzIriSMaVpuBGz1eg6f-6VuuulnoiOpl8A0l50u0MdRjjsxRHuiR2loIhUxoIQQ9xN8w53UiP0R1uz8_uV0_K93RSq37aPjsnCXRLwUUb3azkRVe6B9EUW1ihthQ-KfRaU1iq2rY1m5UqNzf0NqDXCrN5SF-GVxOhKXJTsN4-PABfJBjqxg6dGUGeIa2JhFcA",
Expiry: metav1.NewTime(time1.Add(2 * time.Minute)),
},
}
@ -70,11 +76,15 @@ func TestLogin(t *testing.T) {
}))
t.Cleanup(errorServer.Close)
// Start a test server that returns a real keyset
// Start a test server that returns a real keyset and answers refresh requests.
providerMux := http.NewServeMux()
successServer := httptest.NewServer(providerMux)
t.Cleanup(successServer.Close)
providerMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
return
}
w.Header().Set("content-type", "application/json")
type providerJSON struct {
Issuer string `json:"issuer"`
@ -89,6 +99,44 @@ func TestLogin(t *testing.T) {
JWKSURL: successServer.URL + "/keys",
})
})
providerMux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "unexpected method", http.StatusMethodNotAllowed)
return
}
if err := r.ParseForm(); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if r.Form.Get("client_id") != "test-client-id" {
http.Error(w, "expected client_id 'test-client-id'", http.StatusBadRequest)
return
}
if r.Form.Get("grant_type") != "refresh_token" {
http.Error(w, "expected refresh_token grant type", http.StatusBadRequest)
return
}
var response struct {
oauth2.Token
IDToken string `json:"id_token,omitempty"`
ExpiresIn int64 `json:"expires_in"`
}
response.AccessToken = testToken.AccessToken.Token
response.ExpiresIn = int64(time.Until(testToken.AccessToken.Expiry.Time).Seconds())
response.RefreshToken = testToken.RefreshToken.Token
response.IDToken = testToken.IDToken.Token
if r.Form.Get("refresh_token") == "test-refresh-token-returning-invalid-id-token" {
response.IDToken = "not a valid JWT"
} else if r.Form.Get("refresh_token") != "test-refresh-token" {
http.Error(w, "expected refresh_token to be 'test-refresh-token'", http.StatusBadRequest)
return
}
w.Header().Set("content-type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(&response))
})
tests := []struct {
name string
@ -192,6 +240,106 @@ func TestLogin(t *testing.T) {
issuer: errorServer.URL,
wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL),
},
{
name: "session cache hit with refreshable token",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
IDToken: &IDToken{
Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
},
RefreshToken: &RefreshToken{Token: "test-refresh-token"},
}}
t.Cleanup(func() {
cacheKey := SessionCacheKey{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Len(t, cache.sawPutTokens, 1)
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
})
h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
return nil
}
},
wantToken: &testToken,
},
{
name: "session cache hit but refresh returns invalid token",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
IDToken: &IDToken{
Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
},
RefreshToken: &RefreshToken{Token: "test-refresh-token-returning-invalid-id-token"},
}}
t.Cleanup(func() {
require.Empty(t, cache.sawPutKeys)
require.Empty(t, cache.sawPutTokens)
})
h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
return nil
}
},
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
},
{
name: "session cache hit but refresh fails",
issuer: successServer.URL,
clientID: "not-the-test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
IDToken: &IDToken{
Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
},
RefreshToken: &RefreshToken{Token: "test-refresh-token"},
}}
t.Cleanup(func() {
require.Empty(t, cache.sawPutKeys)
require.Empty(t, cache.sawPutTokens)
})
h.cache = cache
h.oidcDiscover = func(ctx context.Context, iss string) (discoveryI, error) {
provider, err := oidc.NewProvider(ctx, iss)
require.NoError(t, err)
return &mockDiscovery{provider: provider}, nil
}
h.listenAddr = "invalid-listen-address"
return nil
}
},
// Expect this to fall through to the authorization code flow, so it fails here.
wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address",
},
{
name: "listen failure",
opt: func(t *testing.T) Option {
@ -320,7 +468,30 @@ func TestLogin(t *testing.T) {
require.Nil(t, tok)
return
}
require.Equal(t, tt.wantToken, tok)
require.NoError(t, err)
if tt.wantToken == nil {
require.Nil(t, tok)
return
}
require.NotNil(t, tok)
if want := tt.wantToken.AccessToken; want != nil {
require.NotNil(t, tok.AccessToken)
require.Equal(t, want.Token, tok.AccessToken.Token)
require.Equal(t, want.Type, tok.AccessToken.Type)
requireTimeInDelta(t, want.Expiry.Time, tok.AccessToken.Expiry.Time, 5*time.Second)
} else {
assert.Nil(t, tok.AccessToken)
}
require.Equal(t, tt.wantToken.RefreshToken, tok.RefreshToken)
if want := tt.wantToken.IDToken; want != nil {
require.NotNil(t, tok.IDToken)
require.Equal(t, want.Token, tok.IDToken.Token)
requireTimeInDelta(t, want.Expiry.Time, tok.IDToken.Expiry.Time, 5*time.Second)
} else {
assert.Nil(t, tok.IDToken)
}
})
}
}
@ -504,3 +675,22 @@ func mockVerifier() *oidc.IDTokenVerifier {
SkipClientIDCheck: true,
})
}
type mockDiscovery struct{ provider *oidc.Provider }
func (m *mockDiscovery) Endpoint() oauth2.Endpoint { return m.provider.Endpoint() }
func (m *mockDiscovery) Verifier(config *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() }
func requireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) {
require.InDeltaf(t,
float64(t1.UnixNano()),
float64(t2.UnixNano()),
float64(delta.Nanoseconds()),
"expected %s and %s to be < %s apart, but they are %s apart",
t1.Format(time.RFC3339Nano),
t2.Format(time.RFC3339Nano),
delta.String(),
t1.Sub(t2).String(),
)
}