Merge branch 'main' into token-refresh
This commit is contained in:
commit
a9111f39af
@ -13,6 +13,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/coreos/go-oidc"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
clientauthenticationv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1"
|
clientauthenticationv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1"
|
||||||
@ -44,15 +45,17 @@ func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oid
|
|||||||
sessionCachePath string
|
sessionCachePath string
|
||||||
caBundlePaths []string
|
caBundlePaths []string
|
||||||
debugSessionCache bool
|
debugSessionCache bool
|
||||||
|
requestAudience string
|
||||||
)
|
)
|
||||||
cmd.Flags().StringVar(&issuer, "issuer", "", "OpenID Connect issuer URL.")
|
cmd.Flags().StringVar(&issuer, "issuer", "", "OpenID Connect issuer URL.")
|
||||||
cmd.Flags().StringVar(&clientID, "client-id", "", "OpenID Connect client ID.")
|
cmd.Flags().StringVar(&clientID, "client-id", "", "OpenID Connect client ID.")
|
||||||
cmd.Flags().Uint16Var(&listenPort, "listen-port", 0, "TCP port for localhost listener (authorization code flow only).")
|
cmd.Flags().Uint16Var(&listenPort, "listen-port", 0, "TCP port for localhost listener (authorization code flow only).")
|
||||||
cmd.Flags().StringSliceVar(&scopes, "scopes", []string{"offline_access", "openid"}, "OIDC scopes to request during login.")
|
cmd.Flags().StringSliceVar(&scopes, "scopes", []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID}, "OIDC scopes to request during login.")
|
||||||
cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL).")
|
cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL).")
|
||||||
cmd.Flags().StringVar(&sessionCachePath, "session-cache", filepath.Join(mustGetConfigDir(), "sessions.yaml"), "Path to session cache file.")
|
cmd.Flags().StringVar(&sessionCachePath, "session-cache", filepath.Join(mustGetConfigDir(), "sessions.yaml"), "Path to session cache file.")
|
||||||
cmd.Flags().StringSliceVar(&caBundlePaths, "ca-bundle", nil, "Path to TLS certificate authority bundle (PEM format, optional, can be repeated).")
|
cmd.Flags().StringSliceVar(&caBundlePaths, "ca-bundle", nil, "Path to TLS certificate authority bundle (PEM format, optional, can be repeated).")
|
||||||
cmd.Flags().BoolVar(&debugSessionCache, "debug-session-cache", false, "Print debug logs related to the session cache.")
|
cmd.Flags().BoolVar(&debugSessionCache, "debug-session-cache", false, "Print debug logs related to the session cache.")
|
||||||
|
cmd.Flags().StringVar(&requestAudience, "request-audience", "", "Request a token with an alternate audience using RF8693 token exchange.")
|
||||||
mustMarkHidden(&cmd, "debug-session-cache")
|
mustMarkHidden(&cmd, "debug-session-cache")
|
||||||
mustMarkRequired(&cmd, "issuer", "client-id")
|
mustMarkRequired(&cmd, "issuer", "client-id")
|
||||||
|
|
||||||
@ -80,6 +83,10 @@ func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oid
|
|||||||
opts = append(opts, oidcclient.WithListenPort(listenPort))
|
opts = append(opts, oidcclient.WithListenPort(listenPort))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if requestAudience != "" {
|
||||||
|
opts = append(opts, oidcclient.WithRequestAudience(requestAudience))
|
||||||
|
}
|
||||||
|
|
||||||
// --skip-browser replaces the default "browser open" function with one that prints to stderr.
|
// --skip-browser replaces the default "browser open" function with one that prints to stderr.
|
||||||
if skipBrowser {
|
if skipBrowser {
|
||||||
opts = append(opts, oidcclient.WithBrowserOpen(func(url string) error {
|
opts = append(opts, oidcclient.WithBrowserOpen(func(url string) error {
|
||||||
|
@ -46,6 +46,7 @@ func TestLoginOIDCCommand(t *testing.T) {
|
|||||||
-h, --help help for oidc
|
-h, --help help for oidc
|
||||||
--issuer string OpenID Connect issuer URL.
|
--issuer string OpenID Connect issuer URL.
|
||||||
--listen-port uint16 TCP port for localhost listener (authorization code flow only).
|
--listen-port uint16 TCP port for localhost listener (authorization code flow only).
|
||||||
|
--request-audience string Request a token with an alternate audience using RF8693 token exchange.
|
||||||
--scopes strings OIDC scopes to request during login. (default [offline_access,openid])
|
--scopes strings OIDC scopes to request during login. (default [offline_access,openid])
|
||||||
--session-cache string Path to session cache file. (default "` + cfgDir + `/sessions.yaml")
|
--session-cache string Path to session cache file. (default "` + cfgDir + `/sessions.yaml")
|
||||||
--skip-browser Skip opening the browser (just print the URL).
|
--skip-browser Skip opening the browser (just print the URL).
|
||||||
|
@ -6,16 +6,19 @@ package oidcclient
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc"
|
"github.com/coreos/go-oidc"
|
||||||
"github.com/pkg/browser"
|
"github.com/pkg/browser"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
|
||||||
|
|
||||||
"go.pinniped.dev/internal/httputil/httperr"
|
"go.pinniped.dev/internal/httputil/httperr"
|
||||||
"go.pinniped.dev/internal/httputil/securityheader"
|
"go.pinniped.dev/internal/httputil/securityheader"
|
||||||
@ -46,6 +49,8 @@ type handlerState struct {
|
|||||||
scopes []string
|
scopes []string
|
||||||
cache SessionCache
|
cache SessionCache
|
||||||
|
|
||||||
|
requestedAudience string
|
||||||
|
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
|
|
||||||
// Parameters of the localhost listener.
|
// Parameters of the localhost listener.
|
||||||
@ -65,6 +70,7 @@ type handlerState struct {
|
|||||||
generateNonce func() (nonce.Nonce, error)
|
generateNonce func() (nonce.Nonce, error)
|
||||||
openURL func(string) error
|
openURL func(string) error
|
||||||
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
|
getProvider func(*oauth2.Config, *oidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI
|
||||||
|
validateIDToken func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error)
|
||||||
|
|
||||||
callbacks chan callbackResult
|
callbacks chan callbackResult
|
||||||
}
|
}
|
||||||
@ -148,6 +154,14 @@ func WithClient(httpClient *http.Client) Option {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// WithRequestAudience causes the login flow to perform an additional token exchange using the RFC8693 STS flow.
|
||||||
|
func WithRequestAudience(audience string) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
h.requestedAudience = audience
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// nopCache is a SessionCache that doesn't actually do anything.
|
// nopCache is a SessionCache that doesn't actually do anything.
|
||||||
type nopCache struct{}
|
type nopCache struct{}
|
||||||
|
|
||||||
@ -160,7 +174,7 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
|
|||||||
issuer: issuer,
|
issuer: issuer,
|
||||||
clientID: clientID,
|
clientID: clientID,
|
||||||
listenAddr: "localhost:0",
|
listenAddr: "localhost:0",
|
||||||
scopes: []string{"offline_access", "openid", "email", "profile"},
|
scopes: []string{oidc.ScopeOfflineAccess, oidc.ScopeOpenID, "email", "profile"},
|
||||||
cache: &nopCache{},
|
cache: &nopCache{},
|
||||||
callbackPath: "/callback",
|
callbackPath: "/callback",
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
@ -173,6 +187,9 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
|
|||||||
generatePKCE: pkce.Generate,
|
generatePKCE: pkce.Generate,
|
||||||
openURL: browser.OpenURL,
|
openURL: browser.OpenURL,
|
||||||
getProvider: upstreamoidc.New,
|
getProvider: upstreamoidc.New,
|
||||||
|
validateIDToken: func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) {
|
||||||
|
return provider.Verifier(&oidc.Config{ClientID: audience}).Verify(ctx, token)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
if err := opt(&h); err != nil {
|
if err := opt(&h); err != nil {
|
||||||
@ -201,6 +218,26 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do the basic login to get an access and ID token issued to our main client ID.
|
||||||
|
baseToken, err := h.baseLogin()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there is no requested audience, or the requested audience matches the one we got, we're done.
|
||||||
|
if h.requestedAudience == "" || (baseToken.IDToken != nil && h.requestedAudience == baseToken.IDToken.Claims["aud"]) {
|
||||||
|
return baseToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Perform the RFC8693 token exchange.
|
||||||
|
exchangedToken, err := h.tokenExchangeRFC8693(baseToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to exchange token: %w", err)
|
||||||
|
}
|
||||||
|
return exchangedToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handlerState) baseLogin() (*oidctypes.Token, error) {
|
||||||
// Check the cache for a previous session issued with the same parameters.
|
// Check the cache for a previous session issued with the same parameters.
|
||||||
sort.Strings(h.scopes)
|
sort.Strings(h.scopes)
|
||||||
cacheKey := SessionCacheKey{
|
cacheKey := SessionCacheKey{
|
||||||
@ -217,21 +254,13 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Perform OIDC discovery.
|
// Perform OIDC discovery.
|
||||||
h.provider, err = oidc.NewProvider(h.ctx, h.issuer)
|
if err := h.initOIDCDiscovery(); err != nil {
|
||||||
if err != nil {
|
return nil, err
|
||||||
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
|
|
||||||
h.oauth2Config = &oauth2.Config{
|
|
||||||
ClientID: h.clientID,
|
|
||||||
Endpoint: h.provider.Endpoint(),
|
|
||||||
Scopes: h.scopes,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If there was a cached refresh token, attempt to use the refresh flow instead of a fresh login.
|
// If there was a cached refresh token, attempt to use the refresh flow instead of a fresh login.
|
||||||
if cached != nil && cached.RefreshToken != nil && cached.RefreshToken.Token != "" {
|
if cached != nil && cached.RefreshToken != nil && cached.RefreshToken.Token != "" {
|
||||||
freshToken, err := h.handleRefresh(ctx, cached.RefreshToken)
|
freshToken, err := h.handleRefresh(h.ctx, cached.RefreshToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -282,6 +311,95 @@ func Login(issuer string, clientID string, opts ...Option) (*oidctypes.Token, er
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *handlerState) initOIDCDiscovery() error {
|
||||||
|
// Make this method idempotent so it can be called in multiple cases with no extra network requests.
|
||||||
|
if h.provider != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
h.provider, err = oidc.NewProvider(h.ctx, h.issuer)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
|
||||||
|
h.oauth2Config = &oauth2.Config{
|
||||||
|
ClientID: h.clientID,
|
||||||
|
Endpoint: h.provider.Endpoint(),
|
||||||
|
Scopes: h.scopes,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *handlerState) tokenExchangeRFC8693(baseToken *oidctypes.Token) (*oidctypes.Token, error) {
|
||||||
|
// Perform OIDC discovery. This may have already been performed if there was not a cached base token.
|
||||||
|
if err := h.initOIDCDiscovery(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the base access token to authenticate our request. This will populate the "authorization" header.
|
||||||
|
client := oauth2.NewClient(h.ctx, oauth2.StaticTokenSource(&oauth2.Token{AccessToken: baseToken.AccessToken.Token}))
|
||||||
|
|
||||||
|
// Form the HTTP POST request with the parameters specified by RFC8693.
|
||||||
|
reqBody := strings.NewReader(url.Values{
|
||||||
|
"grant_type": []string{"urn:ietf:params:oauth:grant-type:token-exchange"},
|
||||||
|
"audience": []string{h.requestedAudience},
|
||||||
|
"subject_token": []string{baseToken.AccessToken.Token},
|
||||||
|
"subject_token_type": []string{"urn:ietf:params:oauth:token-type:access_token"},
|
||||||
|
"requested_token_type": []string{"urn:ietf:params:oauth:token-type:jwt"},
|
||||||
|
}.Encode())
|
||||||
|
req, err := http.NewRequestWithContext(h.ctx, http.MethodPost, h.oauth2Config.Endpoint.TokenURL, reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not build RFC8693 request: %w", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("content-type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
// Perform the request.
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
|
// Expect an HTTP 200 response with "application/json" content type.
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("unexpected HTTP response status %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
if contentType := resp.Header.Get("content-type"); contentType != "application/json" {
|
||||||
|
return nil, fmt.Errorf("unexpected HTTP response content type %q", contentType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the JSON response body.
|
||||||
|
var respBody struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
IssuedTokenType string `json:"issued_token_type"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expect the token_type and issued_token_type response parameters to have some known values.
|
||||||
|
if respBody.TokenType != "N_A" {
|
||||||
|
return nil, fmt.Errorf("got unexpected token_type %q", respBody.TokenType)
|
||||||
|
}
|
||||||
|
if respBody.IssuedTokenType != "urn:ietf:params:oauth:token-type:jwt" {
|
||||||
|
return nil, fmt.Errorf("got unexpected issued_token_type %q", respBody.IssuedTokenType)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the returned JWT to make sure we got the audience we wanted and extract the expiration time.
|
||||||
|
stsToken, err := h.validateIDToken(h.ctx, h.provider, h.requestedAudience, respBody.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("received invalid JWT: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &oidctypes.Token{IDToken: &oidctypes.IDToken{
|
||||||
|
Token: respBody.AccessToken,
|
||||||
|
Expiry: metav1.NewTime(stsToken.Expiry),
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) {
|
func (h *handlerState) handleRefresh(ctx context.Context, refreshToken *oidctypes.RefreshToken) (*oidctypes.Token, error) {
|
||||||
ctx, cancel := context.WithTimeout(ctx, refreshTimeout)
|
ctx, cancel := context.WithTimeout(ctx, refreshTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
@ -62,12 +62,36 @@ func TestLogin(t *testing.T) {
|
|||||||
IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))},
|
IDToken: &oidctypes.IDToken{Token: "test-id-token", Expiry: metav1.NewTime(time1.Add(2 * time.Minute))},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
testExchangedToken := oidctypes.Token{
|
||||||
|
IDToken: &oidctypes.IDToken{Token: "test-id-token-with-requested-audience", Expiry: metav1.NewTime(time1.Add(3 * time.Minute))},
|
||||||
|
}
|
||||||
|
|
||||||
// Start a test server that returns 500 errors
|
// Start a test server that returns 500 errors
|
||||||
errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Error(w, "some discovery error", http.StatusInternalServerError)
|
http.Error(w, "some discovery error", http.StatusInternalServerError)
|
||||||
}))
|
}))
|
||||||
t.Cleanup(errorServer.Close)
|
t.Cleanup(errorServer.Close)
|
||||||
|
|
||||||
|
// Start a test server that returns discovery data with a broken token URL
|
||||||
|
brokenTokenURLMux := http.NewServeMux()
|
||||||
|
brokenTokenURLServer := httptest.NewServer(brokenTokenURLMux)
|
||||||
|
brokenTokenURLMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("content-type", "application/json")
|
||||||
|
type providerJSON struct {
|
||||||
|
Issuer string `json:"issuer"`
|
||||||
|
AuthURL string `json:"authorization_endpoint"`
|
||||||
|
TokenURL string `json:"token_endpoint"`
|
||||||
|
JWKSURL string `json:"jwks_uri"`
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(&providerJSON{
|
||||||
|
Issuer: brokenTokenURLServer.URL,
|
||||||
|
AuthURL: brokenTokenURLServer.URL + "/authorize",
|
||||||
|
TokenURL: "%",
|
||||||
|
JWKSURL: brokenTokenURLServer.URL + "/keys",
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Cleanup(brokenTokenURLServer.Close)
|
||||||
|
|
||||||
// Start a test server that returns a real discovery document and answers refresh requests.
|
// Start a test server that returns a real discovery document and answers refresh requests.
|
||||||
providerMux := http.NewServeMux()
|
providerMux := http.NewServeMux()
|
||||||
successServer := httptest.NewServer(providerMux)
|
successServer := httptest.NewServer(providerMux)
|
||||||
@ -100,20 +124,21 @@ func TestLogin(t *testing.T) {
|
|||||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if r.Form.Get("client_id") != "test-client-id" {
|
|
||||||
http.Error(w, "expected client_id 'test-client-id'", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if r.Form.Get("grant_type") != "refresh_token" {
|
|
||||||
http.Error(w, "expected refresh_token grant type", http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var response struct {
|
var response struct {
|
||||||
oauth2.Token
|
oauth2.Token
|
||||||
IDToken string `json:"id_token,omitempty"`
|
IDToken string `json:"id_token,omitempty"`
|
||||||
ExpiresIn int64 `json:"expires_in"`
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
IssuedTokenType string `json:"issued_token_type,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch r.Form.Get("grant_type") {
|
||||||
|
case "refresh_token":
|
||||||
|
if r.Form.Get("client_id") != "test-client-id" {
|
||||||
|
http.Error(w, "expected client_id 'test-client-id'", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
response.AccessToken = testToken.AccessToken.Token
|
response.AccessToken = testToken.AccessToken.Token
|
||||||
response.ExpiresIn = int64(time.Until(testToken.AccessToken.Expiry.Time).Seconds())
|
response.ExpiresIn = int64(time.Until(testToken.AccessToken.Expiry.Time).Seconds())
|
||||||
response.RefreshToken = testToken.RefreshToken.Token
|
response.RefreshToken = testToken.RefreshToken.Token
|
||||||
@ -126,6 +151,41 @@ func TestLogin(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "urn:ietf:params:oauth:grant-type:token-exchange":
|
||||||
|
switch r.Form.Get("audience") {
|
||||||
|
case "test-audience-produce-invalid-http-response":
|
||||||
|
http.Redirect(w, r, "%", http.StatusTemporaryRedirect)
|
||||||
|
return
|
||||||
|
case "test-audience-produce-http-400":
|
||||||
|
http.Error(w, "some server error", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
case "test-audience-produce-wrong-content-type":
|
||||||
|
w.Header().Set("content-type", "invalid")
|
||||||
|
return
|
||||||
|
case "test-audience-produce-invalid-json":
|
||||||
|
w.Header().Set("content-type", "application/json")
|
||||||
|
_, _ = w.Write([]byte(`{`))
|
||||||
|
return
|
||||||
|
case "test-audience-produce-invalid-tokentype":
|
||||||
|
response.TokenType = "invalid"
|
||||||
|
case "test-audience-produce-invalid-issuedtokentype":
|
||||||
|
response.TokenType = "N_A"
|
||||||
|
response.IssuedTokenType = "invalid"
|
||||||
|
case "test-audience-produce-invalid-jwt":
|
||||||
|
response.TokenType = "N_A"
|
||||||
|
response.IssuedTokenType = "urn:ietf:params:oauth:token-type:jwt"
|
||||||
|
response.AccessToken = "not-a-valid-jwt"
|
||||||
|
default:
|
||||||
|
response.TokenType = "N_A"
|
||||||
|
response.IssuedTokenType = "urn:ietf:params:oauth:token-type:jwt"
|
||||||
|
response.AccessToken = testExchangedToken.IDToken.Token
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
http.Error(w, fmt.Sprintf("invalid grant_type %q", r.Form.Get("grant_type")), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
w.Header().Set("content-type", "application/json")
|
w.Header().Set("content-type", "application/json")
|
||||||
require.NoError(t, json.NewEncoder(w).Encode(&response))
|
require.NoError(t, json.NewEncoder(w).Encode(&response))
|
||||||
})
|
})
|
||||||
@ -444,6 +504,289 @@ func TestLogin(t *testing.T) {
|
|||||||
issuer: successServer.URL,
|
issuer: successServer.URL,
|
||||||
wantToken: &testToken,
|
wantToken: &testToken,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "with requested audience, session cache hit with valid token, but discovery fails",
|
||||||
|
clientID: "test-client-id",
|
||||||
|
issuer: errorServer.URL,
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.Equal(t, []SessionCacheKey{{
|
||||||
|
Issuer: errorServer.URL,
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: []string{"test-scope"},
|
||||||
|
RedirectURI: "http://localhost:0/callback",
|
||||||
|
}}, cache.sawGetKeys)
|
||||||
|
require.Empty(t, cache.sawPutTokens)
|
||||||
|
})
|
||||||
|
require.NoError(t, WithSessionCache(cache)(h))
|
||||||
|
require.NoError(t, WithRequestAudience("cluster-1234")(h))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantErr: fmt.Sprintf("failed to exchange token: could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with requested audience, session cache hit with valid token, but token URL is invalid",
|
||||||
|
issuer: brokenTokenURLServer.URL,
|
||||||
|
clientID: "test-client-id",
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.Equal(t, []SessionCacheKey{{
|
||||||
|
Issuer: brokenTokenURLServer.URL,
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: []string{"test-scope"},
|
||||||
|
RedirectURI: "http://localhost:0/callback",
|
||||||
|
}}, cache.sawGetKeys)
|
||||||
|
require.Empty(t, cache.sawPutTokens)
|
||||||
|
})
|
||||||
|
require.NoError(t, WithSessionCache(cache)(h))
|
||||||
|
require.NoError(t, WithRequestAudience("cluster-1234")(h))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantErr: `failed to exchange token: could not build RFC8693 request: parse "%": invalid URL escape "%"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with requested audience, session cache hit with valid token, but token exchange request fails",
|
||||||
|
issuer: successServer.URL,
|
||||||
|
clientID: "test-client-id",
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.Equal(t, []SessionCacheKey{{
|
||||||
|
Issuer: successServer.URL,
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: []string{"test-scope"},
|
||||||
|
RedirectURI: "http://localhost:0/callback",
|
||||||
|
}}, cache.sawGetKeys)
|
||||||
|
require.Empty(t, cache.sawPutTokens)
|
||||||
|
})
|
||||||
|
require.NoError(t, WithSessionCache(cache)(h))
|
||||||
|
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-http-response")(h))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantErr: fmt.Sprintf(`failed to exchange token: Post "%s/token": failed to parse Location header "%%": parse "%%": invalid URL escape "%%"`, successServer.URL),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with requested audience, session cache hit with valid token, but token exchange request returns non-200",
|
||||||
|
issuer: successServer.URL,
|
||||||
|
clientID: "test-client-id",
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.Equal(t, []SessionCacheKey{{
|
||||||
|
Issuer: successServer.URL,
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: []string{"test-scope"},
|
||||||
|
RedirectURI: "http://localhost:0/callback",
|
||||||
|
}}, cache.sawGetKeys)
|
||||||
|
require.Empty(t, cache.sawPutTokens)
|
||||||
|
})
|
||||||
|
require.NoError(t, WithSessionCache(cache)(h))
|
||||||
|
require.NoError(t, WithRequestAudience("test-audience-produce-http-400")(h))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantErr: `failed to exchange token: unexpected HTTP response status 400`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with requested audience, session cache hit with valid token, but token exchange request returns wrong content-type",
|
||||||
|
issuer: successServer.URL,
|
||||||
|
clientID: "test-client-id",
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.Equal(t, []SessionCacheKey{{
|
||||||
|
Issuer: successServer.URL,
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: []string{"test-scope"},
|
||||||
|
RedirectURI: "http://localhost:0/callback",
|
||||||
|
}}, cache.sawGetKeys)
|
||||||
|
require.Empty(t, cache.sawPutTokens)
|
||||||
|
})
|
||||||
|
require.NoError(t, WithSessionCache(cache)(h))
|
||||||
|
require.NoError(t, WithRequestAudience("test-audience-produce-wrong-content-type")(h))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantErr: `failed to exchange token: unexpected HTTP response content type "invalid"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid JSON",
|
||||||
|
issuer: successServer.URL,
|
||||||
|
clientID: "test-client-id",
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.Equal(t, []SessionCacheKey{{
|
||||||
|
Issuer: successServer.URL,
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: []string{"test-scope"},
|
||||||
|
RedirectURI: "http://localhost:0/callback",
|
||||||
|
}}, cache.sawGetKeys)
|
||||||
|
require.Empty(t, cache.sawPutTokens)
|
||||||
|
})
|
||||||
|
require.NoError(t, WithSessionCache(cache)(h))
|
||||||
|
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-json")(h))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantErr: `failed to exchange token: failed to decode response: unexpected EOF`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid token_type",
|
||||||
|
issuer: successServer.URL,
|
||||||
|
clientID: "test-client-id",
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.Equal(t, []SessionCacheKey{{
|
||||||
|
Issuer: successServer.URL,
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: []string{"test-scope"},
|
||||||
|
RedirectURI: "http://localhost:0/callback",
|
||||||
|
}}, cache.sawGetKeys)
|
||||||
|
require.Empty(t, cache.sawPutTokens)
|
||||||
|
})
|
||||||
|
require.NoError(t, WithSessionCache(cache)(h))
|
||||||
|
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-tokentype")(h))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantErr: `failed to exchange token: got unexpected token_type "invalid"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid issued_token_type",
|
||||||
|
issuer: successServer.URL,
|
||||||
|
clientID: "test-client-id",
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.Equal(t, []SessionCacheKey{{
|
||||||
|
Issuer: successServer.URL,
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: []string{"test-scope"},
|
||||||
|
RedirectURI: "http://localhost:0/callback",
|
||||||
|
}}, cache.sawGetKeys)
|
||||||
|
require.Empty(t, cache.sawPutTokens)
|
||||||
|
})
|
||||||
|
require.NoError(t, WithSessionCache(cache)(h))
|
||||||
|
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-issuedtokentype")(h))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantErr: `failed to exchange token: got unexpected issued_token_type "invalid"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with requested audience, session cache hit with valid token, but token exchange request returns invalid JWT",
|
||||||
|
issuer: successServer.URL,
|
||||||
|
clientID: "test-client-id",
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.Equal(t, []SessionCacheKey{{
|
||||||
|
Issuer: successServer.URL,
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: []string{"test-scope"},
|
||||||
|
RedirectURI: "http://localhost:0/callback",
|
||||||
|
}}, cache.sawGetKeys)
|
||||||
|
require.Empty(t, cache.sawPutTokens)
|
||||||
|
})
|
||||||
|
require.NoError(t, WithSessionCache(cache)(h))
|
||||||
|
require.NoError(t, WithRequestAudience("test-audience-produce-invalid-jwt")(h))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantErr: `failed to exchange token: received invalid JWT: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with requested audience, session cache hit with valid token, and token exchange request succeeds",
|
||||||
|
issuer: successServer.URL,
|
||||||
|
clientID: "test-client-id",
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.Equal(t, []SessionCacheKey{{
|
||||||
|
Issuer: successServer.URL,
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: []string{"test-scope"},
|
||||||
|
RedirectURI: "http://localhost:0/callback",
|
||||||
|
}}, cache.sawGetKeys)
|
||||||
|
require.Empty(t, cache.sawPutTokens)
|
||||||
|
})
|
||||||
|
require.NoError(t, WithSessionCache(cache)(h))
|
||||||
|
require.NoError(t, WithRequestAudience("test-audience")(h))
|
||||||
|
|
||||||
|
h.validateIDToken = func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) {
|
||||||
|
require.Equal(t, "test-audience", audience)
|
||||||
|
require.Equal(t, "test-id-token-with-requested-audience", token)
|
||||||
|
return &oidc.IDToken{Expiry: testExchangedToken.IDToken.Expiry.Time}, nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantToken: &testExchangedToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with requested audience, session cache hit with valid refresh token, and token exchange request succeeds",
|
||||||
|
issuer: successServer.URL,
|
||||||
|
clientID: "test-client-id",
|
||||||
|
opt: func(t *testing.T) Option {
|
||||||
|
return func(h *handlerState) error {
|
||||||
|
cache := &mockSessionCache{t: t, getReturnsToken: &oidctypes.Token{
|
||||||
|
IDToken: &oidctypes.IDToken{
|
||||||
|
Token: "expired-test-id-token",
|
||||||
|
Expiry: metav1.Now(), // less than Now() + minIDTokenValidity
|
||||||
|
},
|
||||||
|
RefreshToken: &oidctypes.RefreshToken{Token: "test-refresh-token"},
|
||||||
|
}}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cacheKey := SessionCacheKey{
|
||||||
|
Issuer: successServer.URL,
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: []string{"test-scope"},
|
||||||
|
RedirectURI: "http://localhost:0/callback",
|
||||||
|
}
|
||||||
|
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
|
||||||
|
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
|
||||||
|
require.Len(t, cache.sawPutTokens, 1)
|
||||||
|
require.Equal(t, testToken.IDToken.Token, cache.sawPutTokens[0].IDToken.Token)
|
||||||
|
})
|
||||||
|
h.cache = cache
|
||||||
|
|
||||||
|
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
|
||||||
|
mock := mockUpstream(t)
|
||||||
|
mock.EXPECT().
|
||||||
|
ValidateToken(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce("")).
|
||||||
|
Return(&testToken, nil)
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(t, WithRequestAudience("test-audience")(h))
|
||||||
|
|
||||||
|
h.validateIDToken = func(ctx context.Context, provider *oidc.Provider, audience string, token string) (*oidc.IDToken, error) {
|
||||||
|
require.Equal(t, "test-audience", audience)
|
||||||
|
require.Equal(t, "test-id-token-with-requested-audience", token)
|
||||||
|
return &oidc.IDToken{Expiry: testExchangedToken.IDToken.Expiry.Time}, nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
},
|
||||||
|
wantToken: &testExchangedToken,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
tt := tt
|
tt := tt
|
||||||
|
Loading…
Reference in New Issue
Block a user