From c9cf13a01fbf5735010c9be4ff72fc1969b677c0 Mon Sep 17 00:00:00 2001 From: Margo Crawford Date: Tue, 14 Dec 2021 15:27:08 -0800 Subject: [PATCH] Check for issuer if available Signed-off-by: Margo Crawford --- internal/oidc/token/token_handler.go | 17 +++++++----- internal/oidc/token/token_handler_test.go | 31 ++++++++++++++++++++++ internal/upstreamoidc/upstreamoidc.go | 7 ++--- internal/upstreamoidc/upstreamoidc_test.go | 21 ++++++++++++--- 4 files changed, 63 insertions(+), 13 deletions(-) diff --git a/internal/oidc/token/token_handler.go b/internal/oidc/token/token_handler.go index 9e27be10..9ec64dc5 100644 --- a/internal/oidc/token/token_handler.go +++ b/internal/oidc/token/token_handler.go @@ -141,7 +141,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, if len(validatedTokens.IDToken.Claims) != 0 { newSub := claims["sub"] oldDownstreamSubject := session.Fosite.Claims.Subject - oldSub, err := upstreamoidc.ExtractUpstreamSubjectFromDownstream(oldDownstreamSubject) + oldIss, oldSub, err := upstreamoidc.ExtractUpstreamSubjectAndIssuerFromDownstream(oldDownstreamSubject) if err != nil { return errorsx.WithStack(errUpstreamRefreshError.WithHintf("Upstream refresh failed."). WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) @@ -152,14 +152,17 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, } usernameClaim := p.GetUsernameClaim() newUsername := claims[usernameClaim] + oldUsername := session.Fosite.Claims.Extra["username"] // its possible this won't be returned. // but if it is, verify that it hasn't changed. - if newUsername != nil { - oldUsername := session.Fosite.Claims.Extra["username"] - if oldUsername != newUsername { - return errorsx.WithStack(errUpstreamRefreshError.WithHintf( - "Upstream refresh failed.").WithWrap(errors.New("username in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) - } + if newUsername != nil && oldUsername != newUsername { + return errorsx.WithStack(errUpstreamRefreshError.WithHintf( + "Upstream refresh failed.").WithWrap(errors.New("username in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) + } + newIssuer := claims["iss"] + if newIssuer != nil && oldIss != newIssuer { + return errorsx.WithStack(errUpstreamRefreshError.WithHintf( + "Upstream refresh failed.").WithWrap(errors.New("issuer in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) } } diff --git a/internal/oidc/token/token_handler_test.go b/internal/oidc/token/token_handler_test.go index eb81a280..b53987a4 100644 --- a/internal/oidc/token/token_handler_test.go +++ b/internal/oidc/token/token_handler_test.go @@ -1706,6 +1706,37 @@ func TestRefreshGrant(t *testing.T) { }, }, }, + { + name: "refresh grant with changed issuer claim", + idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC( + upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedTokens(&oidctypes.Token{ + IDToken: &oidctypes.IDToken{ + Claims: map[string]interface{}{ + "some-claim": "some-value", + "sub": "some-subject", + "iss": "some-changed-issuer", + }, + }, + }).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()), + authcodeExchange: authcodeExchangeInputs{ + customSessionData: initialUpstreamOIDCCustomSessionData(), + modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") }, + want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()), + }, + refreshRequest: refreshRequestInputs{ + want: tokenEndpointResponseExpectedValues{ + wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(), + wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()), + wantStatus: http.StatusUnauthorized, + wantErrorResponseBody: here.Doc(` + { + "error": "error", + "error_description": "Error during upstream refresh. Upstream refresh failed." + } + `), + }, + }, + }, { name: "upstream ldap refresh happy path", idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{ diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index 1c7b7d8c..53a6eab6 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -238,11 +238,12 @@ func (p *ProviderConfig) tryRevokeRefreshToken( } } -func ExtractUpstreamSubjectFromDownstream(downstreamSubject string) (string, error) { +func ExtractUpstreamSubjectAndIssuerFromDownstream(downstreamSubject string) (string, string, error) { if !strings.Contains(downstreamSubject, "?sub=") { - return "", errors.New("downstream subject did not contain original upstream subject") + return "", "", errors.New("downstream subject did not contain original upstream subject") } - return strings.SplitN(downstreamSubject, "?sub=", 2)[1], nil + split := strings.SplitN(downstreamSubject, "?sub=", 2) + return split[0], split[1], nil } // ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response, diff --git a/internal/upstreamoidc/upstreamoidc_test.go b/internal/upstreamoidc/upstreamoidc_test.go index 7d5552f4..91152c59 100644 --- a/internal/upstreamoidc/upstreamoidc_test.go +++ b/internal/upstreamoidc/upstreamoidc_test.go @@ -910,22 +910,37 @@ func TestProviderConfig(t *testing.T) { } }) - t.Run("ExtractUpstreamSubjectFromDownstream", func(t *testing.T) { + t.Run("ExtractUpstreamSubjectAndIssuerFromDownstream", func(t *testing.T) { tests := []struct { name string downstreamSubject string wantUpstreamSubject string + wantUpstreamIssuer string wantErr string }{ { name: "happy path", downstreamSubject: "https://some-issuer?sub=some-subject", wantUpstreamSubject: "some-subject", + wantUpstreamIssuer: "https://some-issuer", + }, + { + name: "happy path but sub is empty string", // todo i think this should not be the responsibility of this function, even though it's undesirable behavior... + downstreamSubject: "https://some-issuer?sub=", + wantUpstreamSubject: "", + wantUpstreamIssuer: "https://some-issuer", + }, + { + name: "happy path but iss is empty string", + downstreamSubject: "?sub=some-subject", + wantUpstreamSubject: "some-subject", + wantUpstreamIssuer: "", }, { name: "subject in a subject", downstreamSubject: "https://some-other-issuer?sub=https://some-issuer?sub=some-subject", wantUpstreamSubject: "https://some-issuer?sub=some-subject", + wantUpstreamIssuer: "https://some-other-issuer", }, { name: "doesn't contain sub=", @@ -936,17 +951,17 @@ func TestProviderConfig(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - actualUpstreamSubject, err := ExtractUpstreamSubjectFromDownstream(tt.downstreamSubject) + actualUpstreamIssuer, actualUpstreamSubject, err := ExtractUpstreamSubjectAndIssuerFromDownstream(tt.downstreamSubject) if tt.wantErr != "" { require.Error(t, err) require.Equal(t, tt.wantErr, err.Error()) } else { require.NoError(t, err) require.Equal(t, tt.wantUpstreamSubject, actualUpstreamSubject) + require.Equal(t, tt.wantUpstreamIssuer, actualUpstreamIssuer) } }) } - }) t.Run("ExchangeAuthcodeAndValidateTokens", func(t *testing.T) {