Add a --request-audience flag to the pinniped login oidc CLI command.

Signed-off-by: Matt Moyer <moyerm@vmware.com>
Signed-off-by: Margo Crawford <margaretc@vmware.com>
This commit is contained in:
Matt Moyer 2020-12-04 17:33:53 -06:00
parent 6a90a10123
commit bfcd2569e9
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
4 changed files with 510 additions and 42 deletions

View File

@ -44,6 +44,7 @@ func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oid
sessionCachePath string sessionCachePath string
caBundlePaths []string caBundlePaths []string
debugSessionCache bool debugSessionCache bool
requestAudience string
) )
cmd.Flags().StringVar(&issuer, "issuer", "", "OpenID Connect issuer URL.") cmd.Flags().StringVar(&issuer, "issuer", "", "OpenID Connect issuer URL.")
cmd.Flags().StringVar(&clientID, "client-id", "", "OpenID Connect client ID.") cmd.Flags().StringVar(&clientID, "client-id", "", "OpenID Connect client ID.")
@ -53,6 +54,7 @@ func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oid
cmd.Flags().StringVar(&sessionCachePath, "session-cache", filepath.Join(mustGetConfigDir(), "sessions.yaml"), "Path to session cache file.") cmd.Flags().StringVar(&sessionCachePath, "session-cache", filepath.Join(mustGetConfigDir(), "sessions.yaml"), "Path to session cache file.")
cmd.Flags().StringSliceVar(&caBundlePaths, "ca-bundle", nil, "Path to TLS certificate authority bundle (PEM format, optional, can be repeated).") cmd.Flags().StringSliceVar(&caBundlePaths, "ca-bundle", nil, "Path to TLS certificate authority bundle (PEM format, optional, can be repeated).")
cmd.Flags().BoolVar(&debugSessionCache, "debug-session-cache", false, "Print debug logs related to the session cache.") cmd.Flags().BoolVar(&debugSessionCache, "debug-session-cache", false, "Print debug logs related to the session cache.")
cmd.Flags().StringVar(&requestAudience, "request-audience", "", "Request a token with an alternate audience using RF8693 token exchange.")
mustMarkHidden(&cmd, "debug-session-cache") mustMarkHidden(&cmd, "debug-session-cache")
mustMarkRequired(&cmd, "issuer", "client-id") mustMarkRequired(&cmd, "issuer", "client-id")
@ -80,6 +82,10 @@ func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oid
opts = append(opts, oidcclient.WithListenPort(listenPort)) opts = append(opts, oidcclient.WithListenPort(listenPort))
} }
if requestAudience != "" {
opts = append(opts, oidcclient.WithRequestAudience(requestAudience))
}
// --skip-browser replaces the default "browser open" function with one that prints to stderr. // --skip-browser replaces the default "browser open" function with one that prints to stderr.
if skipBrowser { if skipBrowser {
opts = append(opts, oidcclient.WithBrowserOpen(func(url string) error { opts = append(opts, oidcclient.WithBrowserOpen(func(url string) error {

View File

@ -41,14 +41,15 @@ func TestLoginOIDCCommand(t *testing.T) {
oidc --issuer ISSUER --client-id CLIENT_ID [flags] oidc --issuer ISSUER --client-id CLIENT_ID [flags]
Flags: Flags:
--ca-bundle strings Path to TLS certificate authority bundle (PEM format, optional, can be repeated). --ca-bundle strings Path to TLS certificate authority bundle (PEM format, optional, can be repeated).
--client-id string OpenID Connect client ID. --client-id string OpenID Connect client ID.
-h, --help help for oidc -h, --help help for oidc
--issuer string OpenID Connect issuer URL. --issuer string OpenID Connect issuer URL.
--listen-port uint16 TCP port for localhost listener (authorization code flow only). --listen-port uint16 TCP port for localhost listener (authorization code flow only).
--scopes strings OIDC scopes to request during login. (default [offline_access,openid]) --request-audience string Request a token with an alternate audience using RF8693 token exchange.
--session-cache string Path to session cache file. (default "` + cfgDir + `/sessions.yaml") --scopes strings OIDC scopes to request during login. (default [offline_access,openid])
--skip-browser Skip opening the browser (just print the URL). --session-cache string Path to session cache file. (default "` + cfgDir + `/sessions.yaml")
--skip-browser Skip opening the browser (just print the URL).
`), `),
}, },
{ {

View File

@ -6,16 +6,19 @@ package oidcclient
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"sort" "sort"
"strings"
"time" "time"
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/pkg/browser" "github.com/pkg/browser"
"golang.org/x/oauth2" "golang.org/x/oauth2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader" "go.pinniped.dev/internal/httputil/securityheader"
@ -46,6 +49,8 @@ type handlerState struct {
scopes []string scopes []string
cache SessionCache cache SessionCache
requestedAudience string
httpClient *http.Client httpClient *http.Client
// Parameters of the localhost listener. // Parameters of the localhost listener.
@ -60,11 +65,12 @@ type handlerState struct {
pkce pkce.Code pkce pkce.Code
// External calls for things. // External calls for things.
generateState func() (state.State, error) generateState func() (state.State, error)
generatePKCE func() (pkce.Code, error) generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error) generateNonce func() (nonce.Nonce, error)
openURL func(string) error openURL func(string) error
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
callbacks chan callbackResult callbacks chan callbackResult
} }
@ -148,6 +154,14 @@ func WithClient(httpClient *http.Client) Option {
} }
} }
// WithRequestAudience causes the login flow to perform an additional token exchange using the RFC8693 STS flow.
func WithRequestAudience(audience string) Option {
return func(h *handlerState) error {
h.requestedAudience = audience
return nil
}
}
// nopCache is a SessionCache that doesn't actually do anything. // nopCache is a SessionCache that doesn't actually do anything.
type nopCache struct{} type nopCache struct{}
@ -173,6 +187,9 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
generatePKCE: pkce.Generate, generatePKCE: pkce.Generate,
openURL: browser.OpenURL, openURL: browser.OpenURL,
getProvider: upstreamoidc.New, getProvider: upstreamoidc.New,
validateIDToken: func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) {
return provider.Verifier(&oidc.Config{ClientID: audience}).Verify(ctx, token)
},
} }
for _, opt := range opts { for _, opt := range opts {
if err := opt(&h); err != nil { if err := opt(&h); err != nil {
@ -201,6 +218,26 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
return nil, err return nil, err
} }
// Do the basic login to get an access and ID token issued to our main client ID.
baseToken, err := h.baseLogin()
if err != nil {
return nil, err
}
// If there is no requested audience, or the requested audience matches the one we got, we're done.
if h.requestedAudience == "" || (baseToken.IDToken != nil && h.requestedAudience == baseToken.IDToken.Claims["aud"]) {
return baseToken, err
}
// Perform the RFC8693 token exchange.
exchangedToken, err := h.tokenExchangeRFC8693(baseToken)
if err != nil {
return nil, fmt.Errorf("failed to exchange token: %w", err)
}
return exchangedToken, nil
}
func (h *handlerState) baseLogin() (*oidctypes.Token, error) {
// Check the cache for a previous session issued with the same parameters. // Check the cache for a previous session issued with the same parameters.
sort.Strings(h.scopes) sort.Strings(h.scopes)
cacheKey := SessionCacheKey{ cacheKey := SessionCacheKey{
@ -217,21 +254,13 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
} }
// Perform OIDC discovery. // Perform OIDC discovery.
h.provider, err = oidc.NewProvider(h.ctx, h.issuer) if err := h.initOIDCDiscovery(); err != nil {
if err != nil { return nil, err
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
}
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
h.oauth2Config = &oauth2.Config{
ClientID: h.clientID,
Endpoint: h.provider.Endpoint(),
Scopes: h.scopes,
} }
// If there was a cached refresh token, attempt to use the refresh flow instead of a fresh login. // If there was a cached refresh token, attempt to use the refresh flow instead of a fresh login.
if cached != nil && cached.RefreshToken != nil && cached.RefreshToken.Token != "" { if cached != nil && cached.RefreshToken != nil && cached.RefreshToken.Token != "" {
freshToken, err := h.handleRefresh(ctx, cached.RefreshToken) freshToken, err := h.handleRefresh(h.ctx, cached.RefreshToken)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -282,6 +311,95 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
} }
} }
func (h *handlerState) initOIDCDiscovery() error {
// Make this method idempotent so it can be called in multiple cases with no extra network requests.
if h.provider != nil {
return nil
}
var err error
h.provider, err = oidc.NewProvider(h.ctx, h.issuer)
if err != nil {
return fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
}
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
h.oauth2Config = &oauth2.Config{
ClientID: h.clientID,
Endpoint: h.provider.Endpoint(),
Scopes: h.scopes,
}
return nil
}
func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidctypes.Token, error) {
// Perform OIDC discovery. This may have already been performed if there was not a cached base token.
if err := h.initOIDCDiscovery(); err != nil {
return nil, err
}
// Use the base access token to authenticate our request. This will populate the "authorization" header.
client := oauth2.NewClient(h.ctx, oauth2.StaticTokenSource(&oauth2.Token{AccessToken: baseToken.AccessToken.Token}))
// Form the HTTP POST request with the parameters specified by RFC8693.
reqBody := strings.NewReader(url.Values{
"grant_type": []string{"urn:ietf:params:oauth:grant-type:token-exchange"},
"audience": []string{h.requestedAudience},
"subject_token": []string{baseToken.AccessToken.Token},
"subject_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"},
"requested_token_type": []string{"urn:ietf:params:oauth:token-type:jwt"},
}.Encode())
req, err := http.NewRequestWithContext(h.ctx, http.MethodPost, h.oauth2Config.Endpoint.TokenURL, reqBody)
if err != nil {
return nil, fmt.Errorf("could not build RFC8693 request: %w", err)
}
req.Header.Set("content-type", "application/x-www-form-urlencoded")
// Perform the request.
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer func() { _ = resp.Body.Close() }()
// Expect an HTTP 200 response with "application/json" content type.
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected HTTP response status %d", resp.StatusCode)
}
if contentType := resp.Header.Get("content-type"); contentType != "application/json" {
return nil, fmt.Errorf("unexpected HTTP response content type %q", contentType)
}
// Decode the JSON response body.
var respBody struct {
AccessToken string `json:"access_token"`
IssuedTokenType string `json:"issued_token_type"`
TokenType string `json:"token_type"`
}
if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
return nil, fmt.Errorf("failed to decode response: %w", err)
}
// Expect the token_type and issued_token_type response parameters to have some known values.
if respBody.TokenType != "N_A" {
return nil, fmt.Errorf("got unexpected token_type %q", respBody.TokenType)
}
if respBody.IssuedTokenType != "urn:ietf:params:oauth:token-type:jwt" {
return nil, fmt.Errorf("got unexpected issued_token_type %q", respBody.IssuedTokenType)
}
// Validate the returned JWT to make sure we got the audience we wanted and extract the expiration time.
stsToken, err := h.validateIDToken(h.ctx, h.provider, h.requestedAudience, respBody.AccessToken)
if err != nil {
return nil, fmt.Errorf("received invalid JWT: %w", err)
}
return &oidctypes.Token{IDToken: &oidctypes.IDToken{
Token: respBody.AccessToken,
Expiry: metav1.NewTime(stsToken.Expiry),
}}, nil
}
func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) { func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) {
ctx, cancel := context.WithTimeout(ctx, refreshTimeout) ctx, cancel := context.WithTimeout(ctx, refreshTimeout)
defer cancel() defer cancel()

View File

@ -62,12 +62,36 @@ func TestLogin(t *testing.T) {
IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))}, IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))},
} }
testExchangedToken := oidctypes.Token{
IDToken: &oidctypes.IDToken{Token: "test-id-token-with-requested-audience", Expiry: metav1.NewTime(time1.Add(3 * time.Minute))},
}
// Start a test server that returns 500 errors // Start a test server that returns 500 errors
errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "some discovery error", http.StatusInternalServerError) http.Error(w, "some discovery error", http.StatusInternalServerError)
})) }))
t.Cleanup(errorServer.Close) t.Cleanup(errorServer.Close)
// Start a test server that returns discovery data with a broken token URL
brokenTokenURLMux := http.NewServeMux()
brokenTokenURLServer := httptest.NewServer(brokenTokenURLMux)
brokenTokenURLMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", "application/json")
type providerJSON struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
}
_ = json.NewEncoder(w).Encode(&providerJSON{
Issuer: brokenTokenURLServer.URL,
AuthURL: brokenTokenURLServer.URL + "/authorize",
TokenURL: "%",
JWKSURL: brokenTokenURLServer.URL + "/keys",
})
})
t.Cleanup(brokenTokenURLServer.Close)
// Start a test server that returns a real discovery document and answers refresh requests. // Start a test server that returns a real discovery document and answers refresh requests.
providerMux := http.NewServeMux() providerMux := http.NewServeMux()
successServer := httptest.NewServer(providerMux) successServer := httptest.NewServer(providerMux)
@ -100,29 +124,65 @@ func TestLogin(t *testing.T) {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
if r.Form.Get("client_id") != "test-client-id" {
http.Error(w, "expected client_id 'test-client-id'", http.StatusBadRequest)
return
}
if r.Form.Get("grant_type") != "refresh_token" {
http.Error(w, "expected refresh_token grant type", http.StatusBadRequest)
return
}
var response struct { var response struct {
oauth2.Token oauth2.Token
IDToken string `json:"id_token,omitempty"` IDToken string `json:"id_token,omitempty"`
ExpiresIn int64 `json:"expires_in"` ExpiresIn int64 `json:"expires_in"`
IssuedTokenType string `json:"issued_token_type,omitempty"`
} }
response.AccessToken = testToken.AccessToken.Token
response.ExpiresIn = int64(time.Until(testToken.AccessToken.Expiry.Time).Seconds())
response.RefreshToken = testToken.RefreshToken.Token
response.IDToken = testToken.IDToken.Token
if r.Form.Get("refresh_token") == "test-refresh-token-returning-invalid-id-token" { switch r.Form.Get("grant_type") {
response.IDToken = "not a valid JWT" case "refresh_token":
} else if r.Form.Get("refresh_token") != "test-refresh-token" { if r.Form.Get("client_id") != "test-client-id" {
http.Error(w, "expected refresh_token to be 'test-refresh-token'", http.StatusBadRequest) http.Error(w, "expected client_id 'test-client-id'", http.StatusBadRequest)
return
}
response.AccessToken = testToken.AccessToken.Token
response.ExpiresIn = int64(time.Until(testToken.AccessToken.Expiry.Time).Seconds())
response.RefreshToken = testToken.RefreshToken.Token
response.IDToken = testToken.IDToken.Token
if r.Form.Get("refresh_token") == "test-refresh-token-returning-invalid-id-token" {
response.IDToken = "not a valid JWT"
} else if r.Form.Get("refresh_token") != "test-refresh-token" {
http.Error(w, "expected refresh_token to be 'test-refresh-token'", http.StatusBadRequest)
return
}
case "urn:ietf:params:oauth:grant-type:token-exchange":
switch r.Form.Get("audience") {
case "test-audience-produce-invalid-http-response":
http.Redirect(w, r, "%", http.StatusTemporaryRedirect)
return
case "test-audience-produce-http-400":
http.Error(w, "some server error", http.StatusBadRequest)
return
case "test-audience-produce-wrong-content-type":
w.Header().Set("content-type", "invalid")
return
case "test-audience-produce-invalid-json":
w.Header().Set("content-type", "application/json")
_, _ = w.Write([]byte(`{`))
return
case "test-audience-produce-invalid-tokentype":
response.TokenType = "invalid"
case "test-audience-produce-invalid-issuedtokentype":
response.TokenType = "N_A"
response.IssuedTokenType = "invalid"
case "test-audience-produce-invalid-jwt":
response.TokenType = "N_A"
response.IssuedTokenType = "urn:ietf:params:oauth:token-type:jwt"
response.AccessToken = "not-a-valid-jwt"
default:
response.TokenType = "N_A"
response.IssuedTokenType = "urn:ietf:params:oauth:token-type:jwt"
response.AccessToken = testExchangedToken.IDToken.Token
}
default:
http.Error(w, fmt.Sprintf("invalid grant_type %q", r.Form.Get("grant_type")), http.StatusBadRequest)
return return
} }
@ -444,6 +504,289 @@ func TestLogin(t *testing.T) {
issuer: successServer.URL, issuer: successServer.URL,
wantToken: &testToken, wantToken: &testToken,
}, },
{
name: "with requested audience, session cache hit with valid token, but discovery fails",
clientID: "test-client-id",
issuer: errorServer.URL,
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: errorServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("cluster-1234")(h))
return nil
}
},
wantErr: fmt.Sprintf("failed to exchange token: could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL),
},
{
name: "with requested audience, session cache hit with valid token, but token URL is invalid",
issuer: brokenTokenURLServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: brokenTokenURLServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("cluster-1234")(h))
return nil
}
},
wantErr: `failed to exchange token: could not build RFC8693 request: parse "%": invalid URL escape "%"`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request fails",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-http-response")(h))
return nil
}
},
wantErr: fmt.Sprintf(`failed to exchange token: Post "%s/token": failed to parse Location header "%%": parse "%%": invalid URL escape "%%"`, successServer.URL),
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns non-200",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-http-400")(h))
return nil
}
},
wantErr: `failed to exchange token: unexpected HTTP response status 400`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns wrong content-type",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-wrong-content-type")(h))
return nil
}
},
wantErr: `failed to exchange token: unexpected HTTP response content type "invalid"`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid JSON",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-json")(h))
return nil
}
},
wantErr: `failed to exchange token: failed to decode response: unexpected EOF`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid token_type",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-tokentype")(h))
return nil
}
},
wantErr: `failed to exchange token: got unexpected token_type "invalid"`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid issued_token_type",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-issuedtokentype")(h))
return nil
}
},
wantErr: `failed to exchange token: got unexpected issued_token_type "invalid"`,
},
{
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid JWT",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-jwt")(h))
return nil
}
},
wantErr: `failed to exchange token: received invalid JWT: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts`,
},
{
name: "with requested audience, session cache hit with valid token, and token exchange request succeeds",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithRequestAudience("test-audience")(h))
h.validateIDToken = func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) {
require.Equal(t, "test-audience", audience)
require.Equal(t, "test-id-token-with-requested-audience", token)
return &oidc.IDToken{Expiry: testExchangedToken.IDToken.Expiry.Time}, nil
}
return nil
}
},
wantToken: &testExchangedToken,
},
{
name: "with requested audience, session cache hit with valid refresh token, and token exchange request succeeds",
issuer: successServer.URL,
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
IDToken: &oidctypes.IDToken{
Token: "expired-test-id-token",
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
},
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
}}
t.Cleanup(func() {
cacheKey := SessionCacheKey{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Len(t, cache.sawPutTokens, 1)
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
})
h.cache = cache
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t)
mock.EXPECT().
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
Return(&testToken, nil)
return mock
}
require.NoError(t, WithRequestAudience("test-audience")(h))
h.validateIDToken = func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) {
require.Equal(t, "test-audience", audience)
require.Equal(t, "test-id-token-with-requested-audience", token)
return &oidc.IDToken{Expiry: testExchangedToken.IDToken.Expiry.Time}, nil
}
return nil
}
},
wantToken: &testExchangedToken,
},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt