Merge branch 'main' into token-refresh

This commit is contained in:
Ryan Richard 2020-12-08 12:32:41 -08:00
commit a9111f39af
4 changed files with 513 additions and 44 deletions

View File

@ -13,6 +13,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"github.com/coreos/go-oidc"
"github.com/spf13/cobra" "github.com/spf13/cobra"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
clientauthenticationv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1" clientauthenticationv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1"
@ -44,15 +45,17 @@ func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oid
sessionCachePath string sessionCachePath string
caBundlePaths []string caBundlePaths []string
debugSessionCache bool debugSessionCache bool
requestAudience string
) )
cmd.Flags().StringVar(&issuer, "issuer", "", "OpenID Connect issuer URL.") cmd.Flags().StringVar(&issuer, "issuer", "", "OpenID Connect issuer URL.")
cmd.Flags().StringVar(&clientID, "client-id", "", "OpenID Connect client ID.") cmd.Flags().StringVar(&clientID, "client-id", "", "OpenID Connect client ID.")
cmd.Flags().Uint16Var(&listenPort, "listen-port", 0, "TCP port for localhost listener (authorization code flow only).") cmd.Flags().Uint16Var(&listenPort, "listen-port", 0, "TCP port for localhost listener (authorization code flow only).")
cmd.Flags().StringSliceVar(&scopes, "scopes", []string{"offline_access", "openid"}, "OIDC scopes to request during login.") cmd.Flags().StringSliceVar(&scopes, "scopes", []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID}, "OIDC scopes to request during login.")
cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL).") cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL).")
cmd.Flags().StringVar(&sessionCachePath, "session-cache", filepath.Join(mustGetConfigDir(), "sessions.yaml"), "Path to session cache file.") cmd.Flags().StringVar(&sessionCachePath, "session-cache", filepath.Join(mustGetConfigDir(), "sessions.yaml"), "Path to session cache file.")
cmd.Flags().StringSliceVar(&caBundlePaths, "ca-bundle", nil, "Path to TLS certificate authority bundle (PEM format, optional, can be repeated).") cmd.Flags().StringSliceVar(&caBundlePaths, "ca-bundle", nil, "Path to TLS certificate authority bundle (PEM format, optional, can be repeated).")
cmd.Flags().BoolVar(&debugSessionCache, "debug-session-cache", false, "Print debug logs related to the session cache.") cmd.Flags().BoolVar(&debugSessionCache, "debug-session-cache", false, "Print debug logs related to the session cache.")
cmd.Flags().StringVar(&requestAudience, "request-audience", "", "Request a token with an alternate audience using RF8693 token exchange.")
mustMarkHidden(&cmd, "debug-session-cache") mustMarkHidden(&cmd, "debug-session-cache")
mustMarkRequired(&cmd, "issuer", "client-id") mustMarkRequired(&cmd, "issuer", "client-id")
@ -80,6 +83,10 @@ func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oid
opts = append(opts, oidcclient.WithListenPort(listenPort)) opts = append(opts, oidcclient.WithListenPort(listenPort))
} }
if requestAudience != "" {
opts = append(opts, oidcclient.WithRequestAudience(requestAudience))
}
// --skip-browser replaces the default "browser open" function with one that prints to stderr. // --skip-browser replaces the default "browser open" function with one that prints to stderr.
if skipBrowser { if skipBrowser {
opts = append(opts, oidcclient.WithBrowserOpen(func(url string) error { opts = append(opts, oidcclient.WithBrowserOpen(func(url string) error {

View File

@ -46,6 +46,7 @@ func TestLoginOIDCCommand(t *testing.T) {
-h, --help help for oidc -h, --help help for oidc
--issuer string OpenID Connect issuer URL. --issuer string OpenID Connect issuer URL.
--listen-port uint16 TCP port for localhost listener (authorization code flow only). --listen-port uint16 TCP port for localhost listener (authorization code flow only).
--request-audience string Request a token with an alternate audience using RF8693 token exchange.
--scopes strings OIDC scopes to request during login. (default [offline_access,openid]) --scopes strings OIDC scopes to request during login. (default [offline_access,openid])
--session-cache string Path to session cache file. (default "` + cfgDir + `/sessions.yaml") --session-cache string Path to session cache file. (default "` + cfgDir + `/sessions.yaml")
--skip-browser Skip opening the browser (just print the URL). --skip-browser Skip opening the browser (just print the URL).

View File

@ -6,16 +6,19 @@ package oidcclient
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"sort" "sort"
"strings"
"time" "time"
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/pkg/browser" "github.com/pkg/browser"
"golang.org/x/oauth2" "golang.org/x/oauth2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader" "go.pinniped.dev/internal/httputil/securityheader"
@ -46,6 +49,8 @@ type handlerState struct {
scopes []string scopes []string
cache SessionCache cache SessionCache
requestedAudience string
httpClient *http.Client httpClient *http.Client
// Parameters of the localhost listener. // Parameters of the localhost listener.
@ -65,6 +70,7 @@ type handlerState struct {
generateNonce func() (nonce.Nonce, error) generateNonce func() (nonce.Nonce, error)
openURL func(string) error openURL func(string) error
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
callbacks chan callbackResult callbacks chan callbackResult
} }
@ -148,6 +154,14 @@ func WithClient(httpClient *http.Client) Option {
} }
} }
// WithRequestAudience causes the login flow to perform an additional token exchange using the RFC8693 STS flow.
func WithRequestAudience(audience string) Option {
return func(h *handlerState) error {
h.requestedAudience = audience
return nil
}
}
// nopCache is a SessionCache that doesn't actually do anything. // nopCache is a SessionCache that doesn't actually do anything.
type nopCache struct{} type nopCache struct{}
@ -160,7 +174,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
issuer: issuer, issuer: issuer,
clientID: clientID, clientID: clientID,
listenAddr: "localhost:0", listenAddr: "localhost:0",
scopes: []string{"offline_access", "openid", "email", "profile"}, scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID, "email", "profile"},
cache: &nopCache{}, cache: &nopCache{},
callbackPath: "/callback", callbackPath: "/callback",
ctx: context.Background(), ctx: context.Background(),
@ -173,6 +187,9 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
generatePKCE: pkce.Generate, generatePKCE: pkce.Generate,
openURL: browser.OpenURL, openURL: browser.OpenURL,
getProvider: upstreamoidc.New, getProvider: upstreamoidc.New,
validateIDToken: func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) {
return provider.Verifier(&oidc.Config{ClientID: audience}).Verify(ctx, token)
},
} }
for _, opt := range opts { for _, opt := range opts {
if err := opt(&h); err != nil { if err := opt(&h); err != nil {
@ -201,6 +218,26 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
return nil, err return nil, err
} }
// Do the basic login to get an access and ID token issued to our main client ID.
baseToken, err := h.baseLogin()
if err != nil {
return nil, err
}
// If there is no requested audience, or the requested audience matches the one we got, we're done.
if h.requestedAudience == "" || (baseToken.IDToken != nil && h.requestedAudience == baseToken.IDToken.Claims["aud"]) {
return baseToken, err
}
// Perform the RFC8693 token exchange.
exchangedToken, err := h.tokenExchangeRFC8693(baseToken)
if err != nil {
return nil, fmt.Errorf("failed to exchange token: %w", err)
}
return exchangedToken, nil
}
func (h *handlerState) baseLogin() (*oidctypes.Token, error) {
// Check the cache for a previous session issued with the same parameters. // Check the cache for a previous session issued with the same parameters.
sort.Strings(h.scopes) sort.Strings(h.scopes)
cacheKey := SessionCacheKey{ cacheKey := SessionCacheKey{
@ -217,21 +254,13 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
} }
// Perform OIDC discovery. // Perform OIDC discovery.
h.provider, err = oidc.NewProvider(h.ctx, h.issuer) if err := h.initOIDCDiscovery(); err != nil {
if err != nil { return nil, err
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
}
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
h.oauth2Config = &oauth2.Config{
ClientID: h.clientID,
Endpoint: h.provider.Endpoint(),
Scopes: h.scopes,
} }
// If there was a cached refresh token, attempt to use the refresh flow instead of a fresh login. // If there was a cached refresh token, attempt to use the refresh flow instead of a fresh login.
if cached != nil && cached.RefreshToken != nil && cached.RefreshToken.Token != "" { if cached != nil && cached.RefreshToken != nil && cached.RefreshToken.Token != "" {
freshToken, err := h.handleRefresh(ctx, cached.RefreshToken) freshToken, err := h.handleRefresh(h.ctx, cached.RefreshToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -282,6 +311,95 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
} }
} }
func (h *handlerState) initOIDCDiscovery() error {
// Make this method idempotent so it can be called in multiple cases with no extra network requests.
if h.provider != nil {
return nil
}
var err error
h.provider, err = oidc.NewProvider(h.ctx, h.issuer)
if err != nil {
return fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
}
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
h.oauth2Config = &oauth2.Config{
ClientID: h.clientID,
Endpoint: h.provider.Endpoint(),
Scopes: h.scopes,
}
return nil
}
func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidctypes.Token, error) {
// Perform OIDC discovery. This may have already been performed if there was not a cached base token.
if err := h.initOIDCDiscovery(); err != nil {
return nil, err
}
// Use the base access token to authenticate our request. This will populate the "authorization" header.
client := oauth2.NewClient(h.ctx, oauth2.StaticTokenSource(&oauth2.Token{AccessToken: baseToken.AccessToken.Token}))
// Form the HTTP POST request with the parameters specified by RFC8693.
reqBody := strings.NewReader(url.Values{
"grant_type": []string{"urn:ietf:params:oauth:grant-type:token-exchange"},
"audience": []string{h.requestedAudience},
"subject_token": []string{baseToken.AccessToken.Token},
"subject_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"},
"requested_token_type": []string{"urn:ietf:params:oauth:token-type:jwt"},
}.Encode())
req, err := http.NewRequestWithContext(h.ctx, http.MethodPost, h.oauth2Config.Endpoint.TokenURL, reqBody)
if err != nil {
return nil, fmt.Errorf("could not build RFC8693 request: %w", err)
}
req.Header.Set("content-type", "application/x-www-form-urlencoded")
// Perform the request.
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
// Expect an HTTP 200 response with "application/json" content type.
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected HTTP response status %d", resp.StatusCode)
}
if contentType := resp.Header.Get("content-type"); contentType != "application/json" {
return nil, fmt.Errorf("unexpected HTTP response content type %q", contentType)
}
// Decode the JSON response body.
var respBody struct {
AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type"`
TokenType string `json:"token_type"`
}
if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
// Expect the token_type and issued_token_type response parameters to have some known values.
if respBody.TokenType != "N_A" {
return nil, fmt.Errorf("got unexpected token_type %q", respBody.TokenType)
}
if respBody.IssuedTokenType != "urn:ietf:params:oauth:token-type:jwt" {
return nil, fmt.Errorf("got unexpected issued_token_type %q", respBody.IssuedTokenType)
}
// Validate the returned JWT to make sure we got the audience we wanted and extract the expiration time.
stsToken, err := h.validateIDToken(h.ctx, h.provider, h.requestedAudience, respBody.AccessToken)
if err != nil {
return nil, fmt.Errorf("received invalid JWT: %w", err)
}
return &oidctypes.Token{IDToken: &oidctypes.IDToken{
Token: respBody.AccessToken,
Expiry: metav1.NewTime(stsToken.Expiry),
}}, nil
}
func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) { func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) {
ctx, cancel := context.WithTimeout(ctx, refreshTimeout) ctx, cancel := context.WithTimeout(ctx, refreshTimeout)
defer cancel() defer cancel()

View File

@ -62,12 +62,36 @@ func TestLogin(t *testing.T) {
IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))}, IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))},
} }
testExchangedToken := oidctypes.Token{
IDToken: &oidctypes.IDToken{Token: "test-id-token-with-requested-audience", Expiry: metav1.NewTime(time1.Add(3 * time.Minute))},
}
// Start a test server that returns 500 errors // Start a test server that returns 500 errors
errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "some discovery error", http.StatusInternalServerError) http.Error(w, "some discovery error", http.StatusInternalServerError)
})) }))
t.Cleanup(errorServer.Close) t.Cleanup(errorServer.Close)
// Start a test server that returns discovery data with a broken token URL
brokenTokenURLMux := http.NewServeMux()
brokenTokenURLServer := httptest.NewServer(brokenTokenURLMux)
brokenTokenURLMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", "application/json")
type providerJSON struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
}
_ = json.NewEncoder(w).Encode(&providerJSON{
Issuer: brokenTokenURLServer.URL,
AuthURL: brokenTokenURLServer.URL + "/authorize",
TokenURL: "%",
JWKSURL: brokenTokenURLServer.URL + "/keys",
})
})
t.Cleanup(brokenTokenURLServer.Close)
// Start a test server that returns a real discovery document and answers refresh requests. // Start a test server that returns a real discovery document and answers refresh requests.
providerMux := http.NewServeMux() providerMux := http.NewServeMux()
successServer := httptest.NewServer(providerMux) successServer := httptest.NewServer(providerMux)
@ -100,20 +124,21 @@ func TestLogin(t *testing.T) {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
if r.Form.Get("client_id") != "test-client-id" {
http.Error(w, "expected client_id 'test-client-id'", http.StatusBadRequest)
return
}
if r.Form.Get("grant_type") != "refresh_token" {
http.Error(w, "expected refresh_token grant type", http.StatusBadRequest)
return
}
var response struct { var response struct {
oauth2.Token oauth2.Token
IDToken string `json:"id_token,omitempty"` IDToken string `json:"id_token,omitempty"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
IssuedTokenType string `json:"issued_token_type,omitempty"`
} }
switch r.Form.Get("grant_type") {
case "refresh_token":
if r.Form.Get("client_id") != "test-client-id" {
http.Error(w, "expected client_id 'test-client-id'", http.StatusBadRequest)
return
}
response.AccessToken = testToken.AccessToken.Token response.AccessToken = testToken.AccessToken.Token
response.ExpiresIn = int64(time.Until(testToken.AccessToken.Expiry.Time).Seconds()) response.ExpiresIn = int64(time.Until(testToken.AccessToken.Expiry.Time).Seconds())
response.RefreshToken = testToken.RefreshToken.Token response.RefreshToken = testToken.RefreshToken.Token
@ -126,6 +151,41 @@ func TestLogin(t *testing.T) {
return return
} }
case "urn:ietf:params:oauth:grant-type:token-exchange":
switch r.Form.Get("audience") {
case "test-audience-produce-invalid-http-response":
http.Redirect(w, r, "%", http.StatusTemporaryRedirect)
return
case "test-audience-produce-http-400":
http.Error(w, "some server error", http.StatusBadRequest)
return
case "test-audience-produce-wrong-content-type":
w.Header().Set("content-type", "invalid")
return
case "test-audience-produce-invalid-json":
w.Header().Set("content-type", "application/json")
_, _ = w.Write([]byte(`{`))
return
case "test-audience-produce-invalid-tokentype":
response.TokenType = "invalid"
case "test-audience-produce-invalid-issuedtokentype":
response.TokenType = "N_A"
response.IssuedTokenType = "invalid"
case "test-audience-produce-invalid-jwt":
response.TokenType = "N_A"
response.IssuedTokenType = "urn:ietf:params:oauth:token-type:jwt"
response.AccessToken = "not-a-valid-jwt"
default:
response.TokenType = "N_A"
response.IssuedTokenType = "urn:ietf:params:oauth:token-type:jwt"
response.AccessToken = testExchangedToken.IDToken.Token
}
default:
http.Error(w, fmt.Sprintf("invalid grant_type %q", r.Form.Get("grant_type")), http.StatusBadRequest)
return
}
w.Header().Set("content-type", "application/json") w.Header().Set("content-type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(&response)) require.NoError(t, json.NewEncoder(w).Encode(&response))
}) })
@ -444,6 +504,289 @@ func TestLogin(t *testing.T) {
issuer: successServer.URL, issuer: successServer.URL,
wantToken: &testToken, wantToken: &testToken,
}, },
{
name: "with requested audience, session cache hit with valid token, but discovery fails",
clientID: "test-client-id",
issuer: errorServer.URL,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: errorServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("cluster-1234")(h))
return nil
}
},
wantErr: fmt.Sprintf("failed to exchange token: could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL),
},
{
name: "with requested audience, session cache hit with valid token, but token URL is invalid",
issuer: brokenTokenURLServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: brokenTokenURLServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("cluster-1234")(h))
return nil
}
},
wantErr: `failed to exchange token: could not build RFC8693 request: parse "%": invalid URL escape "%"`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request fails",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-http-response")(h))
return nil
}
},
wantErr: fmt.Sprintf(`failed to exchange token: Post "%s/token": failed to parse Location header "%%": parse "%%": invalid URL escape "%%"`, successServer.URL),
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns non-200",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-http-400")(h))
return nil
}
},
wantErr: `failed to exchange token: unexpected HTTP response status 400`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns wrong content-type",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-wrong-content-type")(h))
return nil
}
},
wantErr: `failed to exchange token: unexpected HTTP response content type "invalid"`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid JSON",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-json")(h))
return nil
}
},
wantErr: `failed to exchange token: failed to decode response: unexpected EOF`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid token_type",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-tokentype")(h))
return nil
}
},
wantErr: `failed to exchange token: got unexpected token_type "invalid"`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid issued_token_type",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-issuedtokentype")(h))
return nil
}
},
wantErr: `failed to exchange token: got unexpected issued_token_type "invalid"`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid JWT",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-jwt")(h))
return nil
}
},
wantErr: `failed to exchange token: received invalid JWT: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts`,
},
{
name: "with requested audience, session cache hit with valid token, and token exchange request succeeds",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience")(h))
h.validateIDToken = func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) {
require.Equal(t, "test-audience", audience)
require.Equal(t, "test-id-token-with-requested-audience", token)
return &oidc.IDToken{Expiry: testExchangedToken.IDToken.Expiry.Time}, nil
}
return nil
}
},
wantToken: &testExchangedToken,
},
{
name: "with requested audience, session cache hit with valid refresh token, and token exchange request succeeds",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &oidctypes.IDToken{
Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
},
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
}}
t.Cleanup(func() {
cacheKey := SessionCacheKey{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Len(t, cache.sawPutTokens, 1)
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
})
h.cache = cache
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
Return(&testToken, nil)
return mock
}
require.NoError(t, WithRequestAudience("test-audience")(h))
h.validateIDToken = func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) {
require.Equal(t, "test-audience", audience)
require.Equal(t, "test-id-token-with-requested-audience", token)
return &oidc.IDToken{Expiry: testExchangedToken.IDToken.Expiry.Time}, nil
}
return nil
}
},
wantToken: &testExchangedToken,
},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt