Refactored oidcUpstreamRefresh
Various style changes, updated some comments and variable names and extracted a helper function for validation.
This commit is contained in:
parent
62be761ef1
commit
5b161be334
@ -18,6 +18,7 @@ import (
|
|||||||
"go.pinniped.dev/internal/oidc/provider"
|
"go.pinniped.dev/internal/oidc/provider"
|
||||||
"go.pinniped.dev/internal/plog"
|
"go.pinniped.dev/internal/plog"
|
||||||
"go.pinniped.dev/internal/psession"
|
"go.pinniped.dev/internal/psession"
|
||||||
|
"go.pinniped.dev/pkg/oidcclient/oidctypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -105,10 +106,12 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
|||||||
if s.OIDC == nil {
|
if s.OIDC == nil {
|
||||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||||
}
|
}
|
||||||
|
|
||||||
accessTokenStored := s.OIDC.UpstreamAccessToken != ""
|
accessTokenStored := s.OIDC.UpstreamAccessToken != ""
|
||||||
refreshTokenStored := s.OIDC.UpstreamRefreshToken != ""
|
refreshTokenStored := s.OIDC.UpstreamRefreshToken != ""
|
||||||
refreshTokenOrAccessTokenStored := (accessTokenStored || refreshTokenStored) && !(accessTokenStored && refreshTokenStored)
|
|
||||||
if !refreshTokenOrAccessTokenStored {
|
exactlyOneTokenStored := (accessTokenStored || refreshTokenStored) && !(accessTokenStored && refreshTokenStored)
|
||||||
|
if !exactlyOneTokenStored {
|
||||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -129,9 +132,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
|||||||
).WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
).WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
tokens = &oauth2.Token{
|
tokens = &oauth2.Token{AccessToken: s.OIDC.UpstreamAccessToken}
|
||||||
AccessToken: s.OIDC.UpstreamAccessToken,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upstream refresh may or may not return a new ID token. From the spec:
|
// Upstream refresh may or may not return a new ID token. From the spec:
|
||||||
@ -147,41 +148,16 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
|||||||
"Upstream refresh returned an invalid ID token or UserInfo response.").WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
"Upstream refresh returned an invalid ID token or UserInfo response.").WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||||
}
|
}
|
||||||
|
|
||||||
claims := validatedTokens.IDToken.Claims
|
err = validateIdentityUnchangedSinceInitialLogin(validatedTokens, session, p.GetUsernameClaim())
|
||||||
// if we have any claims at all, we better have a subject, and it better match the previous value.
|
if err != nil {
|
||||||
// but it's possible that we don't because both returning a new refresh token on refresh and having a userinfo
|
return err
|
||||||
// endpoint are optional.
|
|
||||||
if len(validatedTokens.IDToken.Claims) != 0 { //nolint:nestif
|
|
||||||
newSub, hasSub := getString(claims, oidc.IDTokenSubjectClaim)
|
|
||||||
if !hasSub {
|
|
||||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
|
||||||
"Upstream refresh failed.").WithWrap(errors.New("subject in upstream refresh not found")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
|
||||||
}
|
|
||||||
if s.OIDC.UpstreamSubject != newSub {
|
|
||||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
|
||||||
"Upstream refresh failed.").WithWrap(errors.New("subject in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
|
||||||
}
|
|
||||||
usernameClaim := p.GetUsernameClaim()
|
|
||||||
newUsername, hasUsername := getString(claims, usernameClaim)
|
|
||||||
oldUsername := session.Fosite.Claims.Extra[oidc.DownstreamUsernameClaim]
|
|
||||||
// its possible this won't be returned.
|
|
||||||
// but if it is, verify that it hasn't changed.
|
|
||||||
if hasUsername && oldUsername != newUsername {
|
|
||||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
|
||||||
"Upstream refresh failed.").WithWrap(errors.New("username in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
|
||||||
}
|
|
||||||
newIssuer, hasIssuer := getString(claims, oidc.IDTokenIssuerClaim)
|
|
||||||
if hasIssuer && s.OIDC.UpstreamIssuer != newIssuer {
|
|
||||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
|
||||||
"Upstream refresh failed.").WithWrap(errors.New("issuer in upstream refresh does not match previous value")).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Upstream refresh may or may not return a new refresh token. If we got a new refresh token, then update it in
|
// Upstream refresh may or may not return a new refresh token. If we got a new refresh token, then update it in
|
||||||
// the user's session. If we did not get a new refresh token, then keep the old one in the session by avoiding
|
// the user's session. If we did not get a new refresh token, then keep the old one in the session by avoiding
|
||||||
// overwriting the old one.
|
// overwriting the old one.
|
||||||
if tokens.RefreshToken != "" {
|
if tokens.RefreshToken != "" {
|
||||||
plog.Debug("upstream refresh request did not return a new refresh token",
|
plog.Debug("upstream refresh request returned a new refresh token",
|
||||||
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
|
"providerName", s.ProviderName, "providerType", s.ProviderType, "providerUID", s.ProviderUID)
|
||||||
s.OIDC.UpstreamRefreshToken = tokens.RefreshToken
|
s.OIDC.UpstreamRefreshToken = tokens.RefreshToken
|
||||||
}
|
}
|
||||||
@ -189,6 +165,51 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateIdentityUnchangedSinceInitialLogin(validatedTokens *oidctypes.Token, session *psession.PinnipedSession, usernameClaimName string) error {
|
||||||
|
s := session.Custom
|
||||||
|
mergedClaims := validatedTokens.IDToken.Claims
|
||||||
|
|
||||||
|
// If we have any claims at all, we better have a subject, and it better match the previous value.
|
||||||
|
// but it's possible that we don't because both returning a new id token on refresh and having a userinfo
|
||||||
|
// endpoint are optional.
|
||||||
|
if len(mergedClaims) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newSub, hasSub := getString(mergedClaims, oidc.IDTokenSubjectClaim)
|
||||||
|
if !hasSub {
|
||||||
|
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||||
|
"Upstream refresh failed.").WithWrap(errors.New("subject in upstream refresh not found")).
|
||||||
|
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||||
|
}
|
||||||
|
if s.OIDC.UpstreamSubject != newSub {
|
||||||
|
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||||
|
"Upstream refresh failed.").WithWrap(errors.New("subject in upstream refresh does not match previous value")).
|
||||||
|
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||||
|
}
|
||||||
|
|
||||||
|
newUsername, hasUsername := getString(mergedClaims, usernameClaimName)
|
||||||
|
oldUsername := session.Fosite.Claims.Extra[oidc.DownstreamUsernameClaim]
|
||||||
|
// It's possible that a username wasn't returned by the upstream provider during refresh,
|
||||||
|
// but if it is, verify that it hasn't changed.
|
||||||
|
if hasUsername && oldUsername != newUsername {
|
||||||
|
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||||
|
"Upstream refresh failed.").WithWrap(errors.New("username in upstream refresh does not match previous value")).
|
||||||
|
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||||
|
}
|
||||||
|
|
||||||
|
newIssuer, hasIssuer := getString(mergedClaims, oidc.IDTokenIssuerClaim)
|
||||||
|
// It's possible that an issuer wasn't returned by the upstream provider during refresh,
|
||||||
|
// but if it is, verify that it hasn't changed.
|
||||||
|
if hasIssuer && s.OIDC.UpstreamIssuer != newIssuer {
|
||||||
|
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||||
|
"Upstream refresh failed.").WithWrap(errors.New("issuer in upstream refresh does not match previous value")).
|
||||||
|
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func getString(m map[string]interface{}, key string) (string, bool) {
|
func getString(m map[string]interface{}, key string) (string, bool) {
|
||||||
val, ok := m[key].(string)
|
val, ok := m[key].(string)
|
||||||
return val, ok
|
return val, ok
|
||||||
|
Loading…
Reference in New Issue
Block a user