Check for issuer if available
Signed-off-by: Margo Crawford <margaretc@vmware.com>
This commit is contained in:
parent
0cd086cf9c
commit
c9cf13a01f
@ -141,7 +141,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
|||||||
if len(validatedTokens.IDToken.Claims) != 0 {
|
if len(validatedTokens.IDToken.Claims) != 0 {
|
||||||
newSub := claims["sub"]
|
newSub := claims["sub"]
|
||||||
oldDownstreamSubject := session.Fosite.Claims.Subject
|
oldDownstreamSubject := session.Fosite.Claims.Subject
|
||||||
oldSub, err := upstreamoidc.ExtractUpstreamSubjectFromDownstream(oldDownstreamSubject)
|
oldIss, oldSub, err := upstreamoidc.ExtractUpstreamSubjectAndIssuerFromDownstream(oldDownstreamSubject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf("Upstream refresh failed.").
|
return errorsx.WithStack(errUpstreamRefreshError.WithHintf("Upstream refresh failed.").
|
||||||
WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||||
@ -152,14 +152,17 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
|||||||
}
|
}
|
||||||
usernameClaim := p.GetUsernameClaim()
|
usernameClaim := p.GetUsernameClaim()
|
||||||
newUsername := claims[usernameClaim]
|
newUsername := claims[usernameClaim]
|
||||||
|
oldUsername := session.Fosite.Claims.Extra["username"]
|
||||||
// its possible this won't be returned.
|
// its possible this won't be returned.
|
||||||
// but if it is, verify that it hasn't changed.
|
// but if it is, verify that it hasn't changed.
|
||||||
if newUsername != nil {
|
if newUsername != nil && oldUsername != newUsername {
|
||||||
oldUsername := session.Fosite.Claims.Extra["username"]
|
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||||
if oldUsername != newUsername {
|
"Upstream refresh failed.").WithWrap(errors.New("username in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
}
|
||||||
"Upstream refresh failed.").WithWrap(errors.New("username in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
newIssuer := claims["iss"]
|
||||||
}
|
if newIssuer != nil && oldIss != newIssuer {
|
||||||
|
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||||
|
"Upstream refresh failed.").WithWrap(errors.New("issuer in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1706,6 +1706,37 @@ func TestRefreshGrant(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "refresh grant with changed issuer claim",
|
||||||
|
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
|
||||||
|
upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedTokens(&oidctypes.Token{
|
||||||
|
IDToken: &oidctypes.IDToken{
|
||||||
|
Claims: map[string]interface{}{
|
||||||
|
"some-claim": "some-value",
|
||||||
|
"sub": "some-subject",
|
||||||
|
"iss": "some-changed-issuer",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()),
|
||||||
|
authcodeExchange: authcodeExchangeInputs{
|
||||||
|
customSessionData: initialUpstreamOIDCCustomSessionData(),
|
||||||
|
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
|
||||||
|
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()),
|
||||||
|
},
|
||||||
|
refreshRequest: refreshRequestInputs{
|
||||||
|
want: tokenEndpointResponseExpectedValues{
|
||||||
|
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
|
||||||
|
wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()),
|
||||||
|
wantStatus: http.StatusUnauthorized,
|
||||||
|
wantErrorResponseBody: here.Doc(`
|
||||||
|
{
|
||||||
|
"error": "error",
|
||||||
|
"error_description": "Error during upstream refresh. Upstream refresh failed."
|
||||||
|
}
|
||||||
|
`),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "upstream ldap refresh happy path",
|
name: "upstream ldap refresh happy path",
|
||||||
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{
|
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{
|
||||||
|
@ -238,11 +238,12 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExtractUpstreamSubjectFromDownstream(downstreamSubject string) (string, error) {
|
func ExtractUpstreamSubjectAndIssuerFromDownstream(downstreamSubject string) (string, string, error) {
|
||||||
if !strings.Contains(downstreamSubject, "?sub=") {
|
if !strings.Contains(downstreamSubject, "?sub=") {
|
||||||
return "", errors.New("downstream subject did not contain original upstream subject")
|
return "", "", errors.New("downstream subject did not contain original upstream subject")
|
||||||
}
|
}
|
||||||
return strings.SplitN(downstreamSubject, "?sub=", 2)[1], nil
|
split := strings.SplitN(downstreamSubject, "?sub=", 2)
|
||||||
|
return split[0], split[1], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response,
|
// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response,
|
||||||
|
@ -910,22 +910,37 @@ func TestProviderConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("ExtractUpstreamSubjectFromDownstream", func(t *testing.T) {
|
t.Run("ExtractUpstreamSubjectAndIssuerFromDownstream", func(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
downstreamSubject string
|
downstreamSubject string
|
||||||
wantUpstreamSubject string
|
wantUpstreamSubject string
|
||||||
|
wantUpstreamIssuer string
|
||||||
wantErr string
|
wantErr string
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "happy path",
|
name: "happy path",
|
||||||
downstreamSubject: "https://some-issuer?sub=some-subject",
|
downstreamSubject: "https://some-issuer?sub=some-subject",
|
||||||
wantUpstreamSubject: "some-subject",
|
wantUpstreamSubject: "some-subject",
|
||||||
|
wantUpstreamIssuer: "https://some-issuer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "happy path but sub is empty string", // todo i think this should not be the responsibility of this function, even though it's undesirable behavior...
|
||||||
|
downstreamSubject: "https://some-issuer?sub=",
|
||||||
|
wantUpstreamSubject: "",
|
||||||
|
wantUpstreamIssuer: "https://some-issuer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "happy path but iss is empty string",
|
||||||
|
downstreamSubject: "?sub=some-subject",
|
||||||
|
wantUpstreamSubject: "some-subject",
|
||||||
|
wantUpstreamIssuer: "",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "subject in a subject",
|
name: "subject in a subject",
|
||||||
downstreamSubject: "https://some-other-issuer?sub=https://some-issuer?sub=some-subject",
|
downstreamSubject: "https://some-other-issuer?sub=https://some-issuer?sub=some-subject",
|
||||||
wantUpstreamSubject: "https://some-issuer?sub=some-subject",
|
wantUpstreamSubject: "https://some-issuer?sub=some-subject",
|
||||||
|
wantUpstreamIssuer: "https://some-other-issuer",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "doesn't contain sub=",
|
name: "doesn't contain sub=",
|
||||||
@ -936,17 +951,17 @@ func TestProviderConfig(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
tt := tt
|
tt := tt
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
actualUpstreamSubject, err := ExtractUpstreamSubjectFromDownstream(tt.downstreamSubject)
|
actualUpstreamIssuer, actualUpstreamSubject, err := ExtractUpstreamSubjectAndIssuerFromDownstream(tt.downstreamSubject)
|
||||||
if tt.wantErr != "" {
|
if tt.wantErr != "" {
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Equal(t, tt.wantErr, err.Error())
|
require.Equal(t, tt.wantErr, err.Error())
|
||||||
} else {
|
} else {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, tt.wantUpstreamSubject, actualUpstreamSubject)
|
require.Equal(t, tt.wantUpstreamSubject, actualUpstreamSubject)
|
||||||
|
require.Equal(t, tt.wantUpstreamIssuer, actualUpstreamIssuer)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("ExchangeAuthcodeAndValidateTokens", func(t *testing.T) {
|
t.Run("ExchangeAuthcodeAndValidateTokens", func(t *testing.T) {
|
||||||
|
Loading…
Reference in New Issue
Block a user