Check for issuer if available

Signed-off-by: Margo Crawford <margaretc@vmware.com>
This commit is contained in:
Margo Crawford 2021-12-14 15:27:08 -08:00
parent 0cd086cf9c
commit c9cf13a01f
4 changed files with 63 additions and 13 deletions

View File

@ -141,7 +141,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
if len(validatedTokens.IDToken.Claims) != 0 {
newSub := claims["sub"]
oldDownstreamSubject := session.Fosite.Claims.Subject
oldSub, err := upstreamoidc.ExtractUpstreamSubjectFromDownstream(oldDownstreamSubject)
oldIss, oldSub, err := upstreamoidc.ExtractUpstreamSubjectAndIssuerFromDownstream(oldDownstreamSubject)
if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf("Upstream refresh failed.").
WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
@ -152,14 +152,17 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
}
usernameClaim := p.GetUsernameClaim()
newUsername := claims[usernameClaim]
oldUsername := session.Fosite.Claims.Extra["username"]
// its possible this won't be returned.
// but if it is, verify that it hasn't changed.
if newUsername != nil {
oldUsername := session.Fosite.Claims.Extra["username"]
if oldUsername != newUsername {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Upstream refresh failed.").WithWrap(errors.New("username in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
}
if newUsername != nil && oldUsername != newUsername {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Upstream refresh failed.").WithWrap(errors.New("username in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
}
newIssuer := claims["iss"]
if newIssuer != nil && oldIss != newIssuer {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Upstream refresh failed.").WithWrap(errors.New("issuer in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
}
}

View File

@ -1706,6 +1706,37 @@ func TestRefreshGrant(t *testing.T) {
},
},
},
{
name: "refresh grant with changed issuer claim",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithOIDC(
upstreamOIDCIdentityProviderBuilder().WithUsernameClaim("username-claim").WithValidatedTokens(&oidctypes.Token{
IDToken: &oidctypes.IDToken{
Claims: map[string]interface{}{
"some-claim": "some-value",
"sub": "some-subject",
"iss": "some-changed-issuer",
},
},
}).WithRefreshedTokens(refreshedUpstreamTokensWithIDAndRefreshTokens()).Build()),
authcodeExchange: authcodeExchangeInputs{
customSessionData: initialUpstreamOIDCCustomSessionData(),
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(initialUpstreamOIDCCustomSessionData()),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyOIDCUpstreamRefreshCall(),
wantUpstreamOIDCValidateTokenCall: happyUpstreamValidateTokenCall(refreshedUpstreamTokensWithIDAndRefreshTokens()),
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
"error": "error",
"error_description": "Error during upstream refresh. Upstream refresh failed."
}
`),
},
},
},
{
name: "upstream ldap refresh happy path",
idps: oidctestutil.NewUpstreamIDPListerBuilder().WithLDAP(&oidctestutil.TestUpstreamLDAPIdentityProvider{

View File

@ -238,11 +238,12 @@ func (p *ProviderConfig) tryRevokeRefreshToken(
}
}
func ExtractUpstreamSubjectFromDownstream(downstreamSubject string) (string, error) {
func ExtractUpstreamSubjectAndIssuerFromDownstream(downstreamSubject string) (string, string, error) {
if !strings.Contains(downstreamSubject, "?sub=") {
return "", errors.New("downstream subject did not contain original upstream subject")
return "", "", errors.New("downstream subject did not contain original upstream subject")
}
return strings.SplitN(downstreamSubject, "?sub=", 2)[1], nil
split := strings.SplitN(downstreamSubject, "?sub=", 2)
return split[0], split[1], nil
}
// ValidateToken will validate the ID token. It will also merge the claims from the userinfo endpoint response,

View File

@ -910,22 +910,37 @@ func TestProviderConfig(t *testing.T) {
}
})
t.Run("ExtractUpstreamSubjectFromDownstream", func(t *testing.T) {
t.Run("ExtractUpstreamSubjectAndIssuerFromDownstream", func(t *testing.T) {
tests := []struct {
name string
downstreamSubject string
wantUpstreamSubject string
wantUpstreamIssuer string
wantErr string
}{
{
name: "happy path",
downstreamSubject: "https://some-issuer?sub=some-subject",
wantUpstreamSubject: "some-subject",
wantUpstreamIssuer: "https://some-issuer",
},
{
name: "happy path but sub is empty string", // todo i think this should not be the responsibility of this function, even though it's undesirable behavior...
downstreamSubject: "https://some-issuer?sub=",
wantUpstreamSubject: "",
wantUpstreamIssuer: "https://some-issuer",
},
{
name: "happy path but iss is empty string",
downstreamSubject: "?sub=some-subject",
wantUpstreamSubject: "some-subject",
wantUpstreamIssuer: "",
},
{
name: "subject in a subject",
downstreamSubject: "https://some-other-issuer?sub=https://some-issuer?sub=some-subject",
wantUpstreamSubject: "https://some-issuer?sub=some-subject",
wantUpstreamIssuer: "https://some-other-issuer",
},
{
name: "doesn't contain sub=",
@ -936,17 +951,17 @@ func TestProviderConfig(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
actualUpstreamSubject, err := ExtractUpstreamSubjectFromDownstream(tt.downstreamSubject)
actualUpstreamIssuer, actualUpstreamSubject, err := ExtractUpstreamSubjectAndIssuerFromDownstream(tt.downstreamSubject)
if tt.wantErr != "" {
require.Error(t, err)
require.Equal(t, tt.wantErr, err.Error())
} else {
require.NoError(t, err)
require.Equal(t, tt.wantUpstreamSubject, actualUpstreamSubject)
require.Equal(t, tt.wantUpstreamIssuer, actualUpstreamIssuer)
}
})
}
})
t.Run("ExchangeAuthcodeAndValidateTokens", func(t *testing.T) {