Check for issuer if available
Signed-off-by: Margo Crawford <margaretc@vmware.com>
This commit is contained in:
parent
0cd086cf9c
commit
c9cf13a01f
@ -141,7 +141,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
||||
if len(validatedTokens.IDToken.Claims) != 0 {
|
||||
newSub := claims["sub"]
|
||||
oldDownstreamSubject := session.Fosite.Claims.Subject
|
||||
oldSub, err := upstreamoidc.ExtractUpstreamSubjectFromDownstream(oldDownstreamSubject)
|
||||
oldIss, oldSub, err := upstreamoidc.ExtractUpstreamSubjectAndIssuerFromDownstream(oldDownstreamSubject)
|
||||
if err != nil {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf("Upstream refresh failed.").
|
||||
WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
@ -152,14 +152,17 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
||||
}
|
||||
usernameClaim := p.GetUsernameClaim()
|
||||
newUsername := claims[usernameClaim]
|
||||
oldUsername := session.Fosite.Claims.Extra["username"]
|
||||
// its possible this won't be returned.
|
||||
// but if it is, verify that it hasn't changed.
|
||||
if newUsername != nil {
|
||||
oldUsername := session.Fosite.Claims.Extra["username"]
|
||||
if oldUsername != newUsername {
|
||||
if newUsername != nil && oldUsername != newUsername {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed.").WithWrap(errors.New("username in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
newIssuer := claims["iss"]
|
||||
if newIssuer != nil && oldIss != newIssuer {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed.").WithWrap(errors.New("issuer in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1706,6 +1706,37 @@ func TestRefreshGrant(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "refresh grant with changed issuer claim",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
|
||||
upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedTokens(&oidctypes.Token{
|
||||
IDToken: &oidctypes.IDToken{
|
||||
Claims: map[string]interface{}{
|
||||
"some-claim": "some-value",
|
||||
"sub": "some-subject",
|
||||
"iss": "some-changed-issuer",
|
||||
},
|
||||
},
|
||||
}).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()),
|
||||
authcodeExchange: authcodeExchangeInputs{
|
||||
customSessionData: initialUpstreamOIDCCustomSessionData(),
|
||||
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()),
|
||||
},
|
||||
refreshRequest: refreshRequestInputs{
|
||||
want: tokenEndpointResponseExpectedValues{
|
||||
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
|
||||
wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()),
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantErrorResponseBody: here.Doc(`
|
||||
{
|
||||
"error": "error",
|
||||
"error_description": "Error during upstream refresh. Upstream refresh failed."
|
||||
}
|
||||
`),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "upstream ldap refresh happy path",
|
||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{
|
||||
|
@ -238,11 +238,12 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
|
||||
}
|
||||
}
|
||||
|
||||
func ExtractUpstreamSubjectFromDownstream(downstreamSubject string) (string, error) {
|
||||
func ExtractUpstreamSubjectAndIssuerFromDownstream(downstreamSubject string) (string, string, error) {
|
||||
if !strings.Contains(downstreamSubject, "?sub=") {
|
||||
return "", errors.New("downstream subject did not contain original upstream subject")
|
||||
return "", "", errors.New("downstream subject did not contain original upstream subject")
|
||||
}
|
||||
return strings.SplitN(downstreamSubject, "?sub=", 2)[1], nil
|
||||
split := strings.SplitN(downstreamSubject, "?sub=", 2)
|
||||
return split[0], split[1], nil
|
||||
}
|
||||
|
||||
// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response,
|
||||
|
@ -910,22 +910,37 @@ func TestProviderConfig(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ExtractUpstreamSubjectFromDownstream", func(t *testing.T) {
|
||||
t.Run("ExtractUpstreamSubjectAndIssuerFromDownstream", func(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
downstreamSubject string
|
||||
wantUpstreamSubject string
|
||||
wantUpstreamIssuer string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
downstreamSubject: "https://some-issuer?sub=some-subject",
|
||||
wantUpstreamSubject: "some-subject",
|
||||
wantUpstreamIssuer: "https://some-issuer",
|
||||
},
|
||||
{
|
||||
name: "happy path but sub is empty string", // todo i think this should not be the responsibility of this function, even though it's undesirable behavior...
|
||||
downstreamSubject: "https://some-issuer?sub=",
|
||||
wantUpstreamSubject: "",
|
||||
wantUpstreamIssuer: "https://some-issuer",
|
||||
},
|
||||
{
|
||||
name: "happy path but iss is empty string",
|
||||
downstreamSubject: "?sub=some-subject",
|
||||
wantUpstreamSubject: "some-subject",
|
||||
wantUpstreamIssuer: "",
|
||||
},
|
||||
{
|
||||
name: "subject in a subject",
|
||||
downstreamSubject: "https://some-other-issuer?sub=https://some-issuer?sub=some-subject",
|
||||
wantUpstreamSubject: "https://some-issuer?sub=some-subject",
|
||||
wantUpstreamIssuer: "https://some-other-issuer",
|
||||
},
|
||||
{
|
||||
name: "doesn't contain sub=",
|
||||
@ -936,17 +951,17 @@ func TestProviderConfig(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actualUpstreamSubject, err := ExtractUpstreamSubjectFromDownstream(tt.downstreamSubject)
|
||||
actualUpstreamIssuer, actualUpstreamSubject, err := ExtractUpstreamSubjectAndIssuerFromDownstream(tt.downstreamSubject)
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
require.Equal(t, tt.wantErr, err.Error())
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.wantUpstreamSubject, actualUpstreamSubject)
|
||||
require.Equal(t, tt.wantUpstreamIssuer, actualUpstreamIssuer)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
})
|
||||
|
||||
t.Run("ExchangeAuthcodeAndValidateTokens", func(t *testing.T) {
|
||||
|
Loading…
Reference in New Issue
Block a user