Remove our direct dependency on ory/x

ory/x has new releases very often, sometimes multiple times per week,
causing a lot of noise from dependabot. We were barely using it
directly, so replace our direct usages with equivalent code.
This commit is contained in:
Ryan Richard 2022-03-24 10:24:54 -07:00
parent 42bd385cbd
commit 48c5a625a5
3 changed files with 62 additions and 742 deletions

2
go.mod
View File

@ -55,7 +55,6 @@ require (
github.com/joshlf/go-acl v0.0.0-20200411065538-eae00ae38531
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
github.com/ory/fosite v0.42.1
github.com/ory/x v0.0.352
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8
github.com/pkg/errors v0.9.1
github.com/sclevine/agouti v3.0.0+incompatible
@ -141,6 +140,7 @@ require (
github.com/ory/go-acc v0.2.7 // indirect
github.com/ory/go-convenience v0.1.0 // indirect
github.com/ory/viper v1.7.5 // indirect
github.com/ory/x v0.0.214 // indirect
github.com/pborman/uuid v1.2.1 // indirect
github.com/pelletier/go-toml v1.9.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect

686
go.sum

File diff suppressed because it is too large Load Diff

View File

@ -11,7 +11,7 @@ import (
"net/http"
"github.com/ory/fosite"
"github.com/ory/x/errorsx"
errorsx "github.com/pkg/errors"
"golang.org/x/oauth2"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/apiserver/pkg/warning"
@ -24,21 +24,6 @@ import (
"go.pinniped.dev/internal/psession"
)
var (
errMissingUpstreamSessionInternalError = &fosite.RFC6749Error{
ErrorField: "error",
DescriptionField: "There was an internal server error.",
HintField: "Required upstream data not found in session.",
CodeField: http.StatusInternalServerError,
}
errUpstreamRefreshError = &fosite.RFC6749Error{
ErrorField: "error",
DescriptionField: "Error during upstream refresh.",
CodeField: http.StatusUnauthorized,
}
)
func NewHandler(
idpLister oidc.UpstreamIdentityProvidersLister,
oauthHelper fosite.OAuth2Provider,
@ -91,17 +76,34 @@ func NewHandler(
})
}
func errMissingUpstreamSessionInternalError() *fosite.RFC6749Error {
return &fosite.RFC6749Error{
ErrorField: "error",
DescriptionField: "There was an internal server error.",
HintField: "Required upstream data not found in session.",
CodeField: http.StatusInternalServerError,
}
}
func errUpstreamRefreshError() *fosite.RFC6749Error {
return &fosite.RFC6749Error{
ErrorField: "error",
DescriptionField: "Error during upstream refresh.",
CodeField: http.StatusUnauthorized,
}
}
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error {
session := accessRequest.GetSession().(*psession.PinnipedSession)
customSessionData := session.Custom
if customSessionData == nil {
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
return errorsx.WithStack(errMissingUpstreamSessionInternalError())
}
providerName := customSessionData.ProviderName
providerUID := customSessionData.ProviderUID
if providerUID == "" || providerName == "" {
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
return errorsx.WithStack(errMissingUpstreamSessionInternalError())
}
switch customSessionData.ProviderType {
@ -112,14 +114,14 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
case psession.ProviderTypeActiveDirectory:
return upstreamLDAPRefresh(ctx, providerCache, session)
default:
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
return errorsx.WithStack(errMissingUpstreamSessionInternalError())
}
}
func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error {
s := session.Custom
if s.OIDC == nil {
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
return errorsx.WithStack(errMissingUpstreamSessionInternalError())
}
accessTokenStored := s.OIDC.UpstreamAccessToken != ""
@ -127,7 +129,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
exactlyOneTokenStored := (accessTokenStored || refreshTokenStored) && !(accessTokenStored && refreshTokenStored)
if !exactlyOneTokenStored {
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
return errorsx.WithStack(errMissingUpstreamSessionInternalError())
}
p, err := findOIDCProviderByNameAndValidateUID(s, providerCache)
@ -142,9 +144,9 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
if refreshTokenStored {
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHint(
return errUpstreamRefreshError().WithHint(
"Upstream refresh failed.",
).WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
).WithTrace(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
}
} else {
tokens = &oauth2.Token{AccessToken: s.OIDC.UpstreamAccessToken}
@ -163,9 +165,9 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
// least some providers do not include one, so we skip the nonce validation here (but not other validations).
validatedTokens, err := p.ValidateTokenAndMergeWithUserInfo(ctx, tokens, "", hasIDTok, accessTokenStored)
if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Upstream refresh returned an invalid ID token or UserInfo response.").WithWrap(err).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
return errUpstreamRefreshError().WithHintf(
"Upstream refresh returned an invalid ID token or UserInfo response.").WithTrace(err).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
}
mergedClaims := validatedTokens.IDToken.Claims
@ -184,9 +186,9 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
// and let any old groups memberships in the session remain.
refreshedGroups, err := downstreamsession.GetGroupsFromUpstreamIDToken(p, mergedClaims)
if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Upstream refresh error while extracting groups claim.").WithWrap(err).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
return errUpstreamRefreshError().WithHintf(
"Upstream refresh error while extracting groups claim.").WithTrace(err).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
}
if refreshedGroups != nil {
oldGroups, err := getDownstreamGroupsFromPinnipedSession(session)
@ -234,14 +236,14 @@ func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interfac
newSub, hasSub := getString(mergedClaims, oidc.IDTokenSubjectClaim)
if !hasSub {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Upstream refresh failed.").WithWrap(errors.New("subject in upstream refresh not found")).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
return errUpstreamRefreshError().WithHintf(
"Upstream refresh failed.").WithTrace(errors.New("subject in upstream refresh not found")).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
}
if s.OIDC.UpstreamSubject != newSub {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Upstream refresh failed.").WithWrap(errors.New("subject in upstream refresh does not match previous value")).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
return errUpstreamRefreshError().WithHintf(
"Upstream refresh failed.").WithTrace(errors.New("subject in upstream refresh does not match previous value")).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
}
newUsername, hasUsername := getString(mergedClaims, usernameClaimName)
@ -249,18 +251,18 @@ func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interfac
// It's possible that a username wasn't returned by the upstream provider during refresh,
// but if it is, verify that it hasn't changed.
if hasUsername && oldUsername != newUsername {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Upstream refresh failed.").WithWrap(errors.New("username in upstream refresh does not match previous value")).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
return errUpstreamRefreshError().WithHintf(
"Upstream refresh failed.").WithTrace(errors.New("username in upstream refresh does not match previous value")).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
}
newIssuer, hasIssuer := getString(mergedClaims, oidc.IDTokenIssuerClaim)
// It's possible that an issuer wasn't returned by the upstream provider during refresh,
// but if it is, verify that it hasn't changed.
if hasIssuer && s.OIDC.UpstreamIssuer != newIssuer {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Upstream refresh failed.").WithWrap(errors.New("issuer in upstream refresh does not match previous value")).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
return errUpstreamRefreshError().WithHintf(
"Upstream refresh failed.").WithTrace(errors.New("issuer in upstream refresh does not match previous value")).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
}
return nil
@ -278,13 +280,13 @@ func findOIDCProviderByNameAndValidateUID(
for _, p := range providerCache.GetOIDCIdentityProviders() {
if p.GetName() == s.ProviderName {
if p.GetResourceUID() != s.ProviderUID {
return nil, errorsx.WithStack(errUpstreamRefreshError.WithHint(
return nil, errorsx.WithStack(errUpstreamRefreshError().WithHint(
"Provider from upstream session data has changed its resource UID since authentication."))
}
return p, nil
}
}
return nil, errorsx.WithStack(errUpstreamRefreshError.
return nil, errorsx.WithStack(errUpstreamRefreshError().
WithHint("Provider from upstream session data was not found.").
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
}
@ -306,7 +308,7 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
validLDAP := s.ProviderType == psession.ProviderTypeLDAP && s.LDAP != nil && s.LDAP.UserDN != ""
validAD := s.ProviderType == psession.ProviderTypeActiveDirectory && s.ActiveDirectory != nil && s.ActiveDirectory.UserDN != ""
if !(validLDAP || validAD) {
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
return errorsx.WithStack(errMissingUpstreamSessionInternalError())
}
var additionalAttributes map[string]string
@ -322,7 +324,7 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
return err
}
if session.IDTokenClaims().AuthTime.IsZero() {
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
return errorsx.WithStack(errMissingUpstreamSessionInternalError())
}
// run PerformRefresh
groups, err := p.PerformRefresh(ctx, provider.StoredRefreshAttributes{
@ -333,9 +335,9 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
AdditionalAttributes: additionalAttributes,
})
if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHint(
"Upstream refresh failed.").WithWrap(err).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
return errUpstreamRefreshError().WithHint(
"Upstream refresh failed.").WithTrace(err).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
}
// Replace the old value with the new value.
session.Fosite.Claims.Extra[oidc.DownstreamGroupsClaim] = groups
@ -362,7 +364,7 @@ func findLDAPProviderByNameAndValidateUID(
for _, p := range providers {
if p.GetName() == s.ProviderName {
if p.GetResourceUID() != s.ProviderUID {
return nil, "", errorsx.WithStack(errUpstreamRefreshError.WithHint(
return nil, "", errorsx.WithStack(errUpstreamRefreshError().WithHint(
"Provider from upstream session data has changed its resource UID since authentication.").
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
}
@ -370,7 +372,7 @@ func findLDAPProviderByNameAndValidateUID(
}
}
return nil, "", errorsx.WithStack(errUpstreamRefreshError.
return nil, "", errorsx.WithStack(errUpstreamRefreshError().
WithHint("Provider from upstream session data was not found.").
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
}
@ -378,15 +380,15 @@ func findLDAPProviderByNameAndValidateUID(
func getDownstreamUsernameFromPinnipedSession(session *psession.PinnipedSession) (string, error) {
extra := session.Fosite.Claims.Extra
if extra == nil {
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError)
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError())
}
downstreamUsernameInterface := extra[oidc.DownstreamUsernameClaim]
if downstreamUsernameInterface == nil {
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError)
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError())
}
downstreamUsername, ok := downstreamUsernameInterface.(string)
if !ok || len(downstreamUsername) == 0 {
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError)
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError())
}
return downstreamUsername, nil
}
@ -394,22 +396,22 @@ func getDownstreamUsernameFromPinnipedSession(session *psession.PinnipedSession)
func getDownstreamGroupsFromPinnipedSession(session *psession.PinnipedSession) ([]string, error) {
extra := session.Fosite.Claims.Extra
if extra == nil {
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError)
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError())
}
downstreamGroupsInterface := extra[oidc.DownstreamGroupsClaim]
if downstreamGroupsInterface == nil {
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError)
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError())
}
downstreamGroupsInterfaceList, ok := downstreamGroupsInterface.([]interface{})
if !ok {
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError)
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError())
}
downstreamGroups := make([]string, 0, len(downstreamGroupsInterfaceList))
for _, downstreamGroupInterface := range downstreamGroupsInterfaceList {
downstreamGroup, ok := downstreamGroupInterface.(string)
if !ok || len(downstreamGroup) == 0 {
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError)
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError())
}
downstreamGroups = append(downstreamGroups, downstreamGroup)
}