Refactored oidcUpstreamRefresh
Various style changes, updated some comments and variable names and extracted a helper function for validation.
This commit is contained in:
parent
62be761ef1
commit
5b161be334
@ -18,6 +18,7 @@ import (
|
||||
"go.pinniped.dev/internal/oidc/provider"
|
||||
"go.pinniped.dev/internal/plog"
|
||||
"go.pinniped.dev/internal/psession"
|
||||
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -105,10 +106,12 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
||||
if s.OIDC == nil {
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
}
|
||||
|
||||
accessTokenStored := s.OIDC.UpstreamAccessToken != ""
|
||||
refreshTokenStored := s.OIDC.UpstreamRefreshToken != ""
|
||||
refreshTokenOrAccessTokenStored := (accessTokenStored || refreshTokenStored) && !(accessTokenStored && refreshTokenStored)
|
||||
if !refreshTokenOrAccessTokenStored {
|
||||
|
||||
exactlyOneTokenStored := (accessTokenStored || refreshTokenStored) && !(accessTokenStored && refreshTokenStored)
|
||||
if !exactlyOneTokenStored {
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
}
|
||||
|
||||
@ -129,9 +132,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
||||
).WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
} else {
|
||||
tokens = &oauth2.Token{
|
||||
AccessToken: s.OIDC.UpstreamAccessToken,
|
||||
}
|
||||
tokens = &oauth2.Token{AccessToken: s.OIDC.UpstreamAccessToken}
|
||||
}
|
||||
|
||||
// Upstream refresh may or may not return a new ID token. From the spec:
|
||||
@ -147,41 +148,16 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
||||
"Upstream refresh returned an invalid ID token or UserInfo response.").WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
|
||||
claims := validatedTokens.IDToken.Claims
|
||||
// if we have any claims at all, we better have a subject, and it better match the previous value.
|
||||
// but it's possible that we don't because both returning a new refresh token on refresh and having a userinfo
|
||||
// endpoint are optional.
|
||||
if len(validatedTokens.IDToken.Claims) != 0 { //nolint:nestif
|
||||
newSub, hasSub := getString(claims, oidc.IDTokenSubjectClaim)
|
||||
if !hasSub {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed.").WithWrap(errors.New("subject in upstream refresh not found")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
if s.OIDC.UpstreamSubject != newSub {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed.").WithWrap(errors.New("subject in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
usernameClaim := p.GetUsernameClaim()
|
||||
newUsername, hasUsername := getString(claims, usernameClaim)
|
||||
oldUsername := session.Fosite.Claims.Extra[oidc.DownstreamUsernameClaim]
|
||||
// its possible this won't be returned.
|
||||
// but if it is, verify that it hasn't changed.
|
||||
if hasUsername && oldUsername != newUsername {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed.").WithWrap(errors.New("username in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
newIssuer, hasIssuer := getString(claims, oidc.IDTokenIssuerClaim)
|
||||
if hasIssuer && s.OIDC.UpstreamIssuer != newIssuer {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed.").WithWrap(errors.New("issuer in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
err = validateIdentityUnchangedSinceInitialLogin(validatedTokens, session, p.GetUsernameClaim())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Upstream refresh may or may not return a new refresh token. If we got a new refresh token, then update it in
|
||||
// the user's session. If we did not get a new refresh token, then keep the old one in the session by avoiding
|
||||
// overwriting the old one.
|
||||
if tokens.RefreshToken != "" {
|
||||
plog.Debug("upstream refresh request did not return a new refresh token",
|
||||
plog.Debug("upstream refresh request returned a new refresh token",
|
||||
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
|
||||
s.OIDC.UpstreamRefreshToken = tokens.RefreshToken
|
||||
}
|
||||
@ -189,6 +165,51 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateIdentityUnchangedSinceInitialLogin(validatedTokens *oidctypes.Token, session *psession.PinnipedSession, usernameClaimName string) error {
|
||||
s := session.Custom
|
||||
mergedClaims := validatedTokens.IDToken.Claims
|
||||
|
||||
// If we have any claims at all, we better have a subject, and it better match the previous value.
|
||||
// but it's possible that we don't because both returning a new id token on refresh and having a userinfo
|
||||
// endpoint are optional.
|
||||
if len(mergedClaims) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
newSub, hasSub := getString(mergedClaims, oidc.IDTokenSubjectClaim)
|
||||
if !hasSub {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed.").WithWrap(errors.New("subject in upstream refresh not found")).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
if s.OIDC.UpstreamSubject != newSub {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed.").WithWrap(errors.New("subject in upstream refresh does not match previous value")).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
|
||||
newUsername, hasUsername := getString(mergedClaims, usernameClaimName)
|
||||
oldUsername := session.Fosite.Claims.Extra[oidc.DownstreamUsernameClaim]
|
||||
// It's possible that a username wasn't returned by the upstream provider during refresh,
|
||||
// but if it is, verify that it hasn't changed.
|
||||
if hasUsername && oldUsername != newUsername {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed.").WithWrap(errors.New("username in upstream refresh does not match previous value")).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
|
||||
newIssuer, hasIssuer := getString(mergedClaims, oidc.IDTokenIssuerClaim)
|
||||
// It's possible that an issuer wasn't returned by the upstream provider during refresh,
|
||||
// but if it is, verify that it hasn't changed.
|
||||
if hasIssuer && s.OIDC.UpstreamIssuer != newIssuer {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed.").WithWrap(errors.New("issuer in upstream refresh does not match previous value")).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getString(m map[string]interface{}, key string) (string, bool) {
|
||||
val, ok := m[key].(string)
|
||||
return val, ok
|
||||
|
Loading…
Reference in New Issue
Block a user