Remove our direct dependency on ory/x

ory/x has new releases very often, sometimes multiple times per week,
causing a lot of noise from dependabot. We were barely using it
directly, so replace our direct usages with equivalent code.
This commit is contained in:
Ryan Richard 2022-03-24 10:24:54 -07:00
parent 42bd385cbd
commit 48c5a625a5
3 changed files with 62 additions and 742 deletions

2
go.mod
View File

@ -55,7 +55,6 @@ require (
github.com/joshlf/go-acl v0.0.0-20200411065538-eae00ae38531 github.com/joshlf/go-acl v0.0.0-20200411065538-eae00ae38531
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/ory/fosite v0.42.1 github.com/ory/fosite v0.42.1
github.com/ory/x v0.0.352
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/sclevine/agouti v3.0.0+incompatible github.com/sclevine/agouti v3.0.0+incompatible
@ -141,6 +140,7 @@ require (
github.com/ory/go-acc v0.2.7 // indirect github.com/ory/go-acc v0.2.7 // indirect
github.com/ory/go-convenience v0.1.0 // indirect github.com/ory/go-convenience v0.1.0 // indirect
github.com/ory/viper v1.7.5 // indirect github.com/ory/viper v1.7.5 // indirect
github.com/ory/x v0.0.214 // indirect
github.com/pborman/uuid v1.2.1 // indirect github.com/pborman/uuid v1.2.1 // indirect
github.com/pelletier/go-toml v1.9.4 // indirect github.com/pelletier/go-toml v1.9.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect

686
go.sum

File diff suppressed because it is too large Load Diff

View File

@ -11,7 +11,7 @@ import (
"net/http" "net/http"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/x/errorsx" errorsx "github.com/pkg/errors"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apiserver/pkg/warning" "k8s.io/apiserver/pkg/warning"
@ -24,21 +24,6 @@ import (
"go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/psession"
) )
var (
errMissingUpstreamSessionInternalError = &fosite.RFC6749Error{
ErrorField: "error",
DescriptionField: "There was an internal server error.",
HintField: "Required upstream data not found in session.",
CodeField: http.StatusInternalServerError,
}
errUpstreamRefreshError = &fosite.RFC6749Error{
ErrorField: "error",
DescriptionField: "Error during upstream refresh.",
CodeField: http.StatusUnauthorized,
}
)
func NewHandler( func NewHandler(
idpLister oidc.UpstreamIdentityProvidersLister, idpLister oidc.UpstreamIdentityProvidersLister,
oauthHelper fosite.OAuth2Provider, oauthHelper fosite.OAuth2Provider,
@ -91,17 +76,34 @@ func NewHandler(
}) })
} }
func errMissingUpstreamSessionInternalError() *fosite.RFC6749Error {
return &fosite.RFC6749Error{
ErrorField: "error",
DescriptionField: "There was an internal server error.",
HintField: "Required upstream data not found in session.",
CodeField: http.StatusInternalServerError,
}
}
func errUpstreamRefreshError() *fosite.RFC6749Error {
return &fosite.RFC6749Error{
ErrorField: "error",
DescriptionField: "Error during upstream refresh.",
CodeField: http.StatusUnauthorized,
}
}
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error { func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error {
session := accessRequest.GetSession().(*psession.PinnipedSession) session := accessRequest.GetSession().(*psession.PinnipedSession)
customSessionData := session.Custom customSessionData := session.Custom
if customSessionData == nil { if customSessionData == nil {
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
providerName := customSessionData.ProviderName providerName := customSessionData.ProviderName
providerUID := customSessionData.ProviderUID providerUID := customSessionData.ProviderUID
if providerUID == "" || providerName == "" { if providerUID == "" || providerName == "" {
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
switch customSessionData.ProviderType { switch customSessionData.ProviderType {
@ -112,14 +114,14 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
case psession.ProviderTypeActiveDirectory: case psession.ProviderTypeActiveDirectory:
return upstreamLDAPRefresh(ctx, providerCache, session) return upstreamLDAPRefresh(ctx, providerCache, session)
default: default:
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
} }
func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error { func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error {
s := session.Custom s := session.Custom
if s.OIDC == nil { if s.OIDC == nil {
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
accessTokenStored := s.OIDC.UpstreamAccessToken != "" accessTokenStored := s.OIDC.UpstreamAccessToken != ""
@ -127,7 +129,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
exactlyOneTokenStored := (accessTokenStored || refreshTokenStored) && !(accessTokenStored && refreshTokenStored) exactlyOneTokenStored := (accessTokenStored || refreshTokenStored) && !(accessTokenStored && refreshTokenStored)
if !exactlyOneTokenStored { if !exactlyOneTokenStored {
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
p, err := findOIDCProviderByNameAndValidateUID(s, providerCache) p, err := findOIDCProviderByNameAndValidateUID(s, providerCache)
@ -142,9 +144,9 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
if refreshTokenStored { if refreshTokenStored {
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken) tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
if err != nil { if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHint( return errUpstreamRefreshError().WithHint(
"Upstream refresh failed.", "Upstream refresh failed.",
).WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) ).WithTrace(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
} }
} else { } else {
tokens = &oauth2.Token{AccessToken: s.OIDC.UpstreamAccessToken} tokens = &oauth2.Token{AccessToken: s.OIDC.UpstreamAccessToken}
@ -163,9 +165,9 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
// least some providers do not include one, so we skip the nonce validation here (but not other validations). // least some providers do not include one, so we skip the nonce validation here (but not other validations).
validatedTokens, err := p.ValidateTokenAndMergeWithUserInfo(ctx, tokens, "", hasIDTok, accessTokenStored) validatedTokens, err := p.ValidateTokenAndMergeWithUserInfo(ctx, tokens, "", hasIDTok, accessTokenStored)
if err != nil { if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf( return errUpstreamRefreshError().WithHintf(
"Upstream refresh returned an invalid ID token or UserInfo response.").WithWrap(err). "Upstream refresh returned an invalid ID token or UserInfo response.").WithTrace(err).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
} }
mergedClaims := validatedTokens.IDToken.Claims mergedClaims := validatedTokens.IDToken.Claims
@ -184,9 +186,9 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
// and let any old groups memberships in the session remain. // and let any old groups memberships in the session remain.
refreshedGroups, err := downstreamsession.GetGroupsFromUpstreamIDToken(p, mergedClaims) refreshedGroups, err := downstreamsession.GetGroupsFromUpstreamIDToken(p, mergedClaims)
if err != nil { if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf( return errUpstreamRefreshError().WithHintf(
"Upstream refresh error while extracting groups claim.").WithWrap(err). "Upstream refresh error while extracting groups claim.").WithTrace(err).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
} }
if refreshedGroups != nil { if refreshedGroups != nil {
oldGroups, err := getDownstreamGroupsFromPinnipedSession(session) oldGroups, err := getDownstreamGroupsFromPinnipedSession(session)
@ -234,14 +236,14 @@ func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interfac
newSub, hasSub := getString(mergedClaims, oidc.IDTokenSubjectClaim) newSub, hasSub := getString(mergedClaims, oidc.IDTokenSubjectClaim)
if !hasSub { if !hasSub {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf( return errUpstreamRefreshError().WithHintf(
"Upstream refresh failed.").WithWrap(errors.New("subject in upstream refresh not found")). "Upstream refresh failed.").WithTrace(errors.New("subject in upstream refresh not found")).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
} }
if s.OIDC.UpstreamSubject != newSub { if s.OIDC.UpstreamSubject != newSub {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf( return errUpstreamRefreshError().WithHintf(
"Upstream refresh failed.").WithWrap(errors.New("subject in upstream refresh does not match previous value")). "Upstream refresh failed.").WithTrace(errors.New("subject in upstream refresh does not match previous value")).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
} }
newUsername, hasUsername := getString(mergedClaims, usernameClaimName) newUsername, hasUsername := getString(mergedClaims, usernameClaimName)
@ -249,18 +251,18 @@ func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interfac
// It's possible that a username wasn't returned by the upstream provider during refresh, // It's possible that a username wasn't returned by the upstream provider during refresh,
// but if it is, verify that it hasn't changed. // but if it is, verify that it hasn't changed.
if hasUsername && oldUsername != newUsername { if hasUsername && oldUsername != newUsername {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf( return errUpstreamRefreshError().WithHintf(
"Upstream refresh failed.").WithWrap(errors.New("username in upstream refresh does not match previous value")). "Upstream refresh failed.").WithTrace(errors.New("username in upstream refresh does not match previous value")).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
} }
newIssuer, hasIssuer := getString(mergedClaims, oidc.IDTokenIssuerClaim) newIssuer, hasIssuer := getString(mergedClaims, oidc.IDTokenIssuerClaim)
// It's possible that an issuer wasn't returned by the upstream provider during refresh, // It's possible that an issuer wasn't returned by the upstream provider during refresh,
// but if it is, verify that it hasn't changed. // but if it is, verify that it hasn't changed.
if hasIssuer && s.OIDC.UpstreamIssuer != newIssuer { if hasIssuer && s.OIDC.UpstreamIssuer != newIssuer {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf( return errUpstreamRefreshError().WithHintf(
"Upstream refresh failed.").WithWrap(errors.New("issuer in upstream refresh does not match previous value")). "Upstream refresh failed.").WithTrace(errors.New("issuer in upstream refresh does not match previous value")).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
} }
return nil return nil
@ -278,13 +280,13 @@ func findOIDCProviderByNameAndValidateUID(
for _, p := range providerCache.GetOIDCIdentityProviders() { for _, p := range providerCache.GetOIDCIdentityProviders() {
if p.GetName() == s.ProviderName { if p.GetName() == s.ProviderName {
if p.GetResourceUID() != s.ProviderUID { if p.GetResourceUID() != s.ProviderUID {
return nil, errorsx.WithStack(errUpstreamRefreshError.WithHint( return nil, errorsx.WithStack(errUpstreamRefreshError().WithHint(
"Provider from upstream session data has changed its resource UID since authentication.")) "Provider from upstream session data has changed its resource UID since authentication."))
} }
return p, nil return p, nil
} }
} }
return nil, errorsx.WithStack(errUpstreamRefreshError. return nil, errorsx.WithStack(errUpstreamRefreshError().
WithHint("Provider from upstream session data was not found."). WithHint("Provider from upstream session data was not found.").
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
} }
@ -306,7 +308,7 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
validLDAP := s.ProviderType == psession.ProviderTypeLDAP && s.LDAP != nil && s.LDAP.UserDN != "" validLDAP := s.ProviderType == psession.ProviderTypeLDAP && s.LDAP != nil && s.LDAP.UserDN != ""
validAD := s.ProviderType == psession.ProviderTypeActiveDirectory && s.ActiveDirectory != nil && s.ActiveDirectory.UserDN != "" validAD := s.ProviderType == psession.ProviderTypeActiveDirectory && s.ActiveDirectory != nil && s.ActiveDirectory.UserDN != ""
if !(validLDAP || validAD) { if !(validLDAP || validAD) {
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
var additionalAttributes map[string]string var additionalAttributes map[string]string
@ -322,7 +324,7 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
return err return err
} }
if session.IDTokenClaims().AuthTime.IsZero() { if session.IDTokenClaims().AuthTime.IsZero() {
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
// run PerformRefresh // run PerformRefresh
groups, err := p.PerformRefresh(ctx, provider.StoredRefreshAttributes{ groups, err := p.PerformRefresh(ctx, provider.StoredRefreshAttributes{
@ -333,9 +335,9 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
AdditionalAttributes: additionalAttributes, AdditionalAttributes: additionalAttributes,
}) })
if err != nil { if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHint( return errUpstreamRefreshError().WithHint(
"Upstream refresh failed.").WithWrap(err). "Upstream refresh failed.").WithTrace(err).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
} }
// Replace the old value with the new value. // Replace the old value with the new value.
session.Fosite.Claims.Extra[oidc.DownstreamGroupsClaim] = groups session.Fosite.Claims.Extra[oidc.DownstreamGroupsClaim] = groups
@ -362,7 +364,7 @@ func findLDAPProviderByNameAndValidateUID(
for _, p := range providers { for _, p := range providers {
if p.GetName() == s.ProviderName { if p.GetName() == s.ProviderName {
if p.GetResourceUID() != s.ProviderUID { if p.GetResourceUID() != s.ProviderUID {
return nil, "", errorsx.WithStack(errUpstreamRefreshError.WithHint( return nil, "", errorsx.WithStack(errUpstreamRefreshError().WithHint(
"Provider from upstream session data has changed its resource UID since authentication."). "Provider from upstream session data has changed its resource UID since authentication.").
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
} }
@ -370,7 +372,7 @@ func findLDAPProviderByNameAndValidateUID(
} }
} }
return nil, "", errorsx.WithStack(errUpstreamRefreshError. return nil, "", errorsx.WithStack(errUpstreamRefreshError().
WithHint("Provider from upstream session data was not found."). WithHint("Provider from upstream session data was not found.").
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
} }
@ -378,15 +380,15 @@ func findLDAPProviderByNameAndValidateUID(
func getDownstreamUsernameFromPinnipedSession(session *psession.PinnipedSession) (string, error) { func getDownstreamUsernameFromPinnipedSession(session *psession.PinnipedSession) (string, error) {
extra := session.Fosite.Claims.Extra extra := session.Fosite.Claims.Extra
if extra == nil { if extra == nil {
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError) return "", errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
downstreamUsernameInterface := extra[oidc.DownstreamUsernameClaim] downstreamUsernameInterface := extra[oidc.DownstreamUsernameClaim]
if downstreamUsernameInterface == nil { if downstreamUsernameInterface == nil {
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError) return "", errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
downstreamUsername, ok := downstreamUsernameInterface.(string) downstreamUsername, ok := downstreamUsernameInterface.(string)
if !ok || len(downstreamUsername) == 0 { if !ok || len(downstreamUsername) == 0 {
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError) return "", errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
return downstreamUsername, nil return downstreamUsername, nil
} }
@ -394,22 +396,22 @@ func getDownstreamUsernameFromPinnipedSession(session *psession.PinnipedSession)
func getDownstreamGroupsFromPinnipedSession(session *psession.PinnipedSession) ([]string, error) { func getDownstreamGroupsFromPinnipedSession(session *psession.PinnipedSession) ([]string, error) {
extra := session.Fosite.Claims.Extra extra := session.Fosite.Claims.Extra
if extra == nil { if extra == nil {
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError) return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
downstreamGroupsInterface := extra[oidc.DownstreamGroupsClaim] downstreamGroupsInterface := extra[oidc.DownstreamGroupsClaim]
if downstreamGroupsInterface == nil { if downstreamGroupsInterface == nil {
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError) return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
downstreamGroupsInterfaceList, ok := downstreamGroupsInterface.([]interface{}) downstreamGroupsInterfaceList, ok := downstreamGroupsInterface.([]interface{})
if !ok { if !ok {
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError) return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
downstreamGroups := make([]string, 0, len(downstreamGroupsInterfaceList)) downstreamGroups := make([]string, 0, len(downstreamGroupsInterfaceList))
for _, downstreamGroupInterface := range downstreamGroupsInterfaceList { for _, downstreamGroupInterface := range downstreamGroupsInterfaceList {
downstreamGroup, ok := downstreamGroupInterface.(string) downstreamGroup, ok := downstreamGroupInterface.(string)
if !ok || len(downstreamGroup) == 0 { if !ok || len(downstreamGroup) == 0 {
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError) return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
downstreamGroups = append(downstreamGroups, downstreamGroup) downstreamGroups = append(downstreamGroups, downstreamGroup)
} }