Add a redirectURI parameter to ExchangeAuthcodeAndValidateTokens() method.

We missed this in the original interface specification, but the `grant_type=authorization_code` requires it, per RFC6749 (https://tools.ietf.org/html/rfc6749#section-4.1.3).

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2020-12-02 10:36:07 -06:00
parent 4fe691de92
commit fde56164cd
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
10 changed files with 53 additions and 13 deletions

View File

@ -43,9 +43,9 @@ func (m *MockUpstreamOIDCIdentityProviderI) EXPECT() *MockUpstreamOIDCIdentityPr
} }
// ExchangeAuthcodeAndValidateTokens mocks base method // ExchangeAuthcodeAndValidateTokens mocks base method
func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(arg0 context.Context, arg1 string, arg2 pkce.Code, arg3 nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(arg0 context.Context, arg1 string, arg2 pkce.Code, arg3 nonce.Nonce, arg4 string) (oidctypes.Token, map[string]interface{}, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ExchangeAuthcodeAndValidateTokens", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "ExchangeAuthcodeAndValidateTokens", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(oidctypes.Token) ret0, _ := ret[0].(oidctypes.Token)
ret1, _ := ret[1].(map[string]interface{}) ret1, _ := ret[1].(map[string]interface{})
ret2, _ := ret[2].(error) ret2, _ := ret[2].(error)
@ -53,9 +53,9 @@ func (m *MockUpstreamOIDCIdentityProviderI) ExchangeAuthcodeAndValidateTokens(ar
} }
// ExchangeAuthcodeAndValidateTokens indicates an expected call of ExchangeAuthcodeAndValidateTokens // ExchangeAuthcodeAndValidateTokens indicates an expected call of ExchangeAuthcodeAndValidateTokens
func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ExchangeAuthcodeAndValidateTokens(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) ExchangeAuthcodeAndValidateTokens(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExchangeAuthcodeAndValidateTokens", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ExchangeAuthcodeAndValidateTokens), arg0, arg1, arg2, arg3) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExchangeAuthcodeAndValidateTokens", reflect.TypeOf((*MockUpstreamOIDCIdentityProviderI)(nil).ExchangeAuthcodeAndValidateTokens), arg0, arg1, arg2, arg3, arg4)
} }
// GetAuthorizationURL mocks base method // GetAuthorizationURL mocks base method

View File

@ -42,6 +42,7 @@ func NewHandler(
idpListGetter oidc.IDPListGetter, idpListGetter oidc.IDPListGetter,
oauthHelper fosite.OAuth2Provider, oauthHelper fosite.OAuth2Provider,
stateDecoder, cookieDecoder oidc.Decoder, stateDecoder, cookieDecoder oidc.Decoder,
redirectURI string,
) http.Handler { ) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
state, err := validateRequest(r, stateDecoder, cookieDecoder) state, err := validateRequest(r, stateDecoder, cookieDecoder)
@ -77,6 +78,7 @@ func NewHandler(
authcode(r), authcode(r),
state.PKCECode, state.PKCECode,
state.Nonce, state.Nonce,
redirectURI,
) )
if err != nil { if err != nil {
plog.WarningErr("error exchanging and validating upstream tokens", err, "upstreamName", upstreamIDPConfig.GetName()) plog.WarningErr("error exchanging and validating upstream tokens", err, "upstreamName", upstreamIDPConfig.GetName())

View File

@ -43,6 +43,8 @@ const (
happyUpstreamAuthcode = "upstream-auth-code" happyUpstreamAuthcode = "upstream-auth-code"
happyUpstreamRedirectURI = "https://example.com/callback"
happyDownstreamState = "some-downstream-state-with-at-least-32-bytes" happyDownstreamState = "some-downstream-state-with-at-least-32-bytes"
happyDownstreamCSRF = "test-csrf" happyDownstreamCSRF = "test-csrf"
happyDownstreamPKCE = "test-pkce" happyDownstreamPKCE = "test-pkce"
@ -105,6 +107,7 @@ func TestCallbackEndpoint(t *testing.T) {
Authcode: happyUpstreamAuthcode, Authcode: happyUpstreamAuthcode,
PKCECodeVerifier: pkce.Code(happyDownstreamPKCE), PKCECodeVerifier: pkce.Code(happyDownstreamPKCE),
ExpectedIDTokenNonce: nonce.Nonce(happyDownstreamNonce), ExpectedIDTokenNonce: nonce.Nonce(happyDownstreamNonce),
RedirectURI: happyUpstreamRedirectURI,
} }
// Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it // Note that fosite puts the granted scopes as a param in the redirect URI even though the spec doesn't seem to require it
@ -433,7 +436,7 @@ func TestCallbackEndpoint(t *testing.T) {
oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret) oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret)
idpListGetter := oidctestutil.NewIDPListGetter(&test.idp) idpListGetter := oidctestutil.NewIDPListGetter(&test.idp)
subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec) subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec, happyUpstreamRedirectURI)
req := httptest.NewRequest(test.method, test.path, nil) req := httptest.NewRequest(test.method, test.path, nil)
if test.csrfCookie != "" { if test.csrfCookie != "" {
req.Header.Set("Cookie", test.csrfCookie) req.Header.Set("Cookie", test.csrfCookie)

View File

@ -24,6 +24,7 @@ type ExchangeAuthcodeAndValidateTokenArgs struct {
Authcode string Authcode string
PKCECodeVerifier pkce.Code PKCECodeVerifier pkce.Code
ExpectedIDTokenNonce nonce.Nonce ExpectedIDTokenNonce nonce.Nonce
RedirectURI string
} }
type TestUpstreamOIDCIdentityProvider struct { type TestUpstreamOIDCIdentityProvider struct {
@ -73,6 +74,7 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens(
authcode string, authcode string,
pkceCodeVerifier pkce.Code, pkceCodeVerifier pkce.Code,
expectedIDTokenNonce nonce.Nonce, expectedIDTokenNonce nonce.Nonce,
redirectURI string,
) (oidctypes.Token, map[string]interface{}, error) { ) (oidctypes.Token, map[string]interface{}, error) {
if u.exchangeAuthcodeAndValidateTokensArgs == nil { if u.exchangeAuthcodeAndValidateTokensArgs == nil {
u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0) u.exchangeAuthcodeAndValidateTokensArgs = make([]*ExchangeAuthcodeAndValidateTokenArgs, 0)
@ -83,6 +85,7 @@ func (u *TestUpstreamOIDCIdentityProvider) ExchangeAuthcodeAndValidateTokens(
Authcode: authcode, Authcode: authcode,
PKCECodeVerifier: pkceCodeVerifier, PKCECodeVerifier: pkceCodeVerifier,
ExpectedIDTokenNonce: expectedIDTokenNonce, ExpectedIDTokenNonce: expectedIDTokenNonce,
RedirectURI: redirectURI,
}) })
return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce) return u.ExchangeAuthcodeAndValidateTokensFunc(ctx, authcode, pkceCodeVerifier, expectedIDTokenNonce)
} }

View File

@ -42,6 +42,7 @@ type UpstreamOIDCIdentityProviderI interface {
authcode string, authcode string,
pkceCodeVerifier pkce.Code, pkceCodeVerifier pkce.Code,
expectedIDTokenNonce nonce.Nonce, expectedIDTokenNonce nonce.Nonce,
redirectURI string,
) (tokens oidctypes.Token, parsedIDTokenClaims map[string]interface{}, err error) ) (tokens oidctypes.Token, parsedIDTokenClaims map[string]interface{}, err error)
ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) ValidateToken(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error)

View File

@ -83,10 +83,25 @@ func (m *Manager) SetProviders(oidcProviders ...*provider.OIDCProvider) {
encoder.SetSerializer(securecookie.JSONEncoder{}) encoder.SetSerializer(securecookie.JSONEncoder{})
authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath authURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.AuthorizationEndpointPath
m.providerHandlers[authURL] = auth.NewHandler(incomingProvider.Issuer(), m.idpListGetter, oauthHelper, csrftoken.Generate, pkce.Generate, nonce.Generate, encoder, encoder) m.providerHandlers[authURL] = auth.NewHandler(
incomingProvider.Issuer(),
m.idpListGetter,
oauthHelper,
csrftoken.Generate,
pkce.Generate,
nonce.Generate,
encoder,
encoder,
)
callbackURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.CallbackEndpointPath callbackURL := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() + oidc.CallbackEndpointPath
m.providerHandlers[callbackURL] = callback.NewHandler(m.idpListGetter, oauthHelper, encoder, encoder) m.providerHandlers[callbackURL] = callback.NewHandler(
m.idpListGetter,
oauthHelper,
encoder,
encoder,
incomingProvider.Issuer()+oidc.CallbackEndpointPath,
)
plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer()) plog.Debug("oidc provider manager added or updated issuer", "issuer", incomingProvider.Issuer())
} }

View File

@ -61,8 +61,13 @@ func (p *ProviderConfig) GetGroupsClaim() string {
return p.GroupsClaim return p.GroupsClaim
} }
func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce) (oidctypes.Token, map[string]interface{}, error) { func (p *ProviderConfig) ExchangeAuthcodeAndValidateTokens(ctx context.Context, authcode string, pkceCodeVerifier pkce.Code, expectedIDTokenNonce nonce.Nonce, redirectURI string) (oidctypes.Token, map[string]interface{}, error) {
tok, err := p.Config.Exchange(oidc.ClientContext(ctx, p.Client), authcode, pkceCodeVerifier.Verifier()) tok, err := p.Config.Exchange(
oidc.ClientContext(ctx, p.Client),
authcode,
pkceCodeVerifier.Verifier(),
oauth2.SetAuthURLParam("redirect_uri", redirectURI),
)
if err != nil { if err != nil {
return oidctypes.Token{}, nil, err return oidctypes.Token{}, nil, err
} }

View File

@ -181,7 +181,7 @@ func TestProviderConfig(t *testing.T) {
ctx := context.Background() ctx := context.Background()
tok, claims, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce) tok, claims, err := p.ExchangeAuthcodeAndValidateTokens(ctx, tt.authCode, "test-pkce", tt.expectNonce, "https://example.com/callback")
if tt.wantErr != "" { if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr) require.EqualError(t, err, tt.wantErr)
require.Equal(t, oidctypes.Token{}, tok) require.Equal(t, oidctypes.Token{}, tok)

View File

@ -328,7 +328,14 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
// Exchange the authorization code for access, ID, and refresh tokens and perform required // Exchange the authorization code for access, ID, and refresh tokens and perform required
// validations on the returned ID token. // validations on the returned ID token.
token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).ExchangeAuthcodeAndValidateTokens(r.Context(), params.Get("code"), h.pkce, h.nonce) token, _, err := h.getProvider(h.oauth2Config, h.provider, h.httpClient).
ExchangeAuthcodeAndValidateTokens(
r.Context(),
params.Get("code"),
h.pkce,
h.nonce,
h.oauth2Config.RedirectURL,
)
if err != nil { if err != nil {
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err) return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
} }

View File

@ -488,6 +488,8 @@ func TestLogin(t *testing.T) {
} }
func TestHandleAuthCodeCallback(t *testing.T) { func TestHandleAuthCodeCallback(t *testing.T) {
const testRedirectURI = "http://127.0.0.1:12324/callback"
tests := []struct { tests := []struct {
name string name string
method string method string
@ -522,10 +524,11 @@ func TestHandleAuthCodeCallback(t *testing.T) {
wantHTTPStatus: http.StatusBadRequest, wantHTTPStatus: http.StatusBadRequest,
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error")) Return(oidctypes.Token{}, nil, fmt.Errorf("some exchange error"))
return mock return mock
} }
@ -538,10 +541,11 @@ func TestHandleAuthCodeCallback(t *testing.T) {
query: "state=test-state&code=valid", query: "state=test-state&code=valid",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce")). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil) Return(oidctypes.Token{IDToken: &oidctypes.IDToken{Token: "test-id-token"}}, nil, nil)
return mock return mock
} }