Merge pull request #1092 from vmware-tanzu/remove_oryx_direct_dep
Remove direct dependency on ory/x
This commit is contained in:
commit
9c5adad062
2
go.mod
2
go.mod
@ -55,7 +55,6 @@ require (
|
||||
github.com/joshlf/go-acl v0.0.0-20200411065538-eae00ae38531
|
||||
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826
|
||||
github.com/ory/fosite v0.42.1
|
||||
github.com/ory/x v0.0.352
|
||||
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/sclevine/agouti v3.0.0+incompatible
|
||||
@ -141,6 +140,7 @@ require (
|
||||
github.com/ory/go-acc v0.2.7 // indirect
|
||||
github.com/ory/go-convenience v0.1.0 // indirect
|
||||
github.com/ory/viper v1.7.5 // indirect
|
||||
github.com/ory/x v0.0.214 // indirect
|
||||
github.com/pborman/uuid v1.2.1 // indirect
|
||||
github.com/pelletier/go-toml v1.9.4 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
|
@ -11,7 +11,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/ory/fosite"
|
||||
"github.com/ory/x/errorsx"
|
||||
errorsx "github.com/pkg/errors"
|
||||
"golang.org/x/oauth2"
|
||||
"k8s.io/apimachinery/pkg/util/sets"
|
||||
"k8s.io/apiserver/pkg/warning"
|
||||
@ -24,21 +24,6 @@ import (
|
||||
"go.pinniped.dev/internal/psession"
|
||||
)
|
||||
|
||||
var (
|
||||
errMissingUpstreamSessionInternalError = &fosite.RFC6749Error{
|
||||
ErrorField: "error",
|
||||
DescriptionField: "There was an internal server error.",
|
||||
HintField: "Required upstream data not found in session.",
|
||||
CodeField: http.StatusInternalServerError,
|
||||
}
|
||||
|
||||
errUpstreamRefreshError = &fosite.RFC6749Error{
|
||||
ErrorField: "error",
|
||||
DescriptionField: "Error during upstream refresh.",
|
||||
CodeField: http.StatusUnauthorized,
|
||||
}
|
||||
)
|
||||
|
||||
func NewHandler(
|
||||
idpLister oidc.UpstreamIdentityProvidersLister,
|
||||
oauthHelper fosite.OAuth2Provider,
|
||||
@ -91,17 +76,34 @@ func NewHandler(
|
||||
})
|
||||
}
|
||||
|
||||
func errMissingUpstreamSessionInternalError() *fosite.RFC6749Error {
|
||||
return &fosite.RFC6749Error{
|
||||
ErrorField: "error",
|
||||
DescriptionField: "There was an internal server error.",
|
||||
HintField: "Required upstream data not found in session.",
|
||||
CodeField: http.StatusInternalServerError,
|
||||
}
|
||||
}
|
||||
|
||||
func errUpstreamRefreshError() *fosite.RFC6749Error {
|
||||
return &fosite.RFC6749Error{
|
||||
ErrorField: "error",
|
||||
DescriptionField: "Error during upstream refresh.",
|
||||
CodeField: http.StatusUnauthorized,
|
||||
}
|
||||
}
|
||||
|
||||
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error {
|
||||
session := accessRequest.GetSession().(*psession.PinnipedSession)
|
||||
|
||||
customSessionData := session.Custom
|
||||
if customSessionData == nil {
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError())
|
||||
}
|
||||
providerName := customSessionData.ProviderName
|
||||
providerUID := customSessionData.ProviderUID
|
||||
if providerUID == "" || providerName == "" {
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError())
|
||||
}
|
||||
|
||||
switch customSessionData.ProviderType {
|
||||
@ -112,14 +114,14 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
|
||||
case psession.ProviderTypeActiveDirectory:
|
||||
return upstreamLDAPRefresh(ctx, providerCache, session)
|
||||
default:
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError())
|
||||
}
|
||||
}
|
||||
|
||||
func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession, providerCache oidc.UpstreamIdentityProvidersLister) error {
|
||||
s := session.Custom
|
||||
if s.OIDC == nil {
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError())
|
||||
}
|
||||
|
||||
accessTokenStored := s.OIDC.UpstreamAccessToken != ""
|
||||
@ -127,7 +129,7 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
||||
|
||||
exactlyOneTokenStored := (accessTokenStored || refreshTokenStored) && !(accessTokenStored && refreshTokenStored)
|
||||
if !exactlyOneTokenStored {
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError())
|
||||
}
|
||||
|
||||
p, err := findOIDCProviderByNameAndValidateUID(s, providerCache)
|
||||
@ -142,9 +144,9 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
||||
if refreshTokenStored {
|
||||
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
|
||||
if err != nil {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHint(
|
||||
return errUpstreamRefreshError().WithHint(
|
||||
"Upstream refresh failed.",
|
||||
).WithWrap(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
).WithTrace(err).WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
|
||||
}
|
||||
} else {
|
||||
tokens = &oauth2.Token{AccessToken: s.OIDC.UpstreamAccessToken}
|
||||
@ -163,9 +165,9 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
||||
// least some providers do not include one, so we skip the nonce validation here (but not other validations).
|
||||
validatedTokens, err := p.ValidateTokenAndMergeWithUserInfo(ctx, tokens, "", hasIDTok, accessTokenStored)
|
||||
if err != nil {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh returned an invalid ID token or UserInfo response.").WithWrap(err).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
return errUpstreamRefreshError().WithHintf(
|
||||
"Upstream refresh returned an invalid ID token or UserInfo response.").WithTrace(err).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
|
||||
}
|
||||
mergedClaims := validatedTokens.IDToken.Claims
|
||||
|
||||
@ -184,9 +186,9 @@ func upstreamOIDCRefresh(ctx context.Context, session *psession.PinnipedSession,
|
||||
// and let any old groups memberships in the session remain.
|
||||
refreshedGroups, err := downstreamsession.GetGroupsFromUpstreamIDToken(p, mergedClaims)
|
||||
if err != nil {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh error while extracting groups claim.").WithWrap(err).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
return errUpstreamRefreshError().WithHintf(
|
||||
"Upstream refresh error while extracting groups claim.").WithTrace(err).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
|
||||
}
|
||||
if refreshedGroups != nil {
|
||||
oldGroups, err := getDownstreamGroupsFromPinnipedSession(session)
|
||||
@ -234,14 +236,14 @@ func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interfac
|
||||
|
||||
newSub, hasSub := getString(mergedClaims, oidc.IDTokenSubjectClaim)
|
||||
if !hasSub {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed.").WithWrap(errors.New("subject in upstream refresh not found")).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
return errUpstreamRefreshError().WithHintf(
|
||||
"Upstream refresh failed.").WithTrace(errors.New("subject in upstream refresh not found")).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
|
||||
}
|
||||
if s.OIDC.UpstreamSubject != newSub {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed.").WithWrap(errors.New("subject in upstream refresh does not match previous value")).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
return errUpstreamRefreshError().WithHintf(
|
||||
"Upstream refresh failed.").WithTrace(errors.New("subject in upstream refresh does not match previous value")).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
|
||||
}
|
||||
|
||||
newUsername, hasUsername := getString(mergedClaims, usernameClaimName)
|
||||
@ -249,18 +251,18 @@ func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interfac
|
||||
// It's possible that a username wasn't returned by the upstream provider during refresh,
|
||||
// but if it is, verify that it hasn't changed.
|
||||
if hasUsername && oldUsername != newUsername {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed.").WithWrap(errors.New("username in upstream refresh does not match previous value")).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
return errUpstreamRefreshError().WithHintf(
|
||||
"Upstream refresh failed.").WithTrace(errors.New("username in upstream refresh does not match previous value")).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
|
||||
}
|
||||
|
||||
newIssuer, hasIssuer := getString(mergedClaims, oidc.IDTokenIssuerClaim)
|
||||
// It's possible that an issuer wasn't returned by the upstream provider during refresh,
|
||||
// but if it is, verify that it hasn't changed.
|
||||
if hasIssuer && s.OIDC.UpstreamIssuer != newIssuer {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
|
||||
"Upstream refresh failed.").WithWrap(errors.New("issuer in upstream refresh does not match previous value")).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
return errUpstreamRefreshError().WithHintf(
|
||||
"Upstream refresh failed.").WithTrace(errors.New("issuer in upstream refresh does not match previous value")).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -278,13 +280,13 @@ func findOIDCProviderByNameAndValidateUID(
|
||||
for _, p := range providerCache.GetOIDCIdentityProviders() {
|
||||
if p.GetName() == s.ProviderName {
|
||||
if p.GetResourceUID() != s.ProviderUID {
|
||||
return nil, errorsx.WithStack(errUpstreamRefreshError.WithHint(
|
||||
return nil, errorsx.WithStack(errUpstreamRefreshError().WithHint(
|
||||
"Provider from upstream session data has changed its resource UID since authentication."))
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
return nil, errorsx.WithStack(errUpstreamRefreshError.
|
||||
return nil, errorsx.WithStack(errUpstreamRefreshError().
|
||||
WithHint("Provider from upstream session data was not found.").
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
@ -306,7 +308,7 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
|
||||
validLDAP := s.ProviderType == psession.ProviderTypeLDAP && s.LDAP != nil && s.LDAP.UserDN != ""
|
||||
validAD := s.ProviderType == psession.ProviderTypeActiveDirectory && s.ActiveDirectory != nil && s.ActiveDirectory.UserDN != ""
|
||||
if !(validLDAP || validAD) {
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError())
|
||||
}
|
||||
|
||||
var additionalAttributes map[string]string
|
||||
@ -322,7 +324,7 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
|
||||
return err
|
||||
}
|
||||
if session.IDTokenClaims().AuthTime.IsZero() {
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
return errorsx.WithStack(errMissingUpstreamSessionInternalError())
|
||||
}
|
||||
// run PerformRefresh
|
||||
groups, err := p.PerformRefresh(ctx, provider.StoredRefreshAttributes{
|
||||
@ -333,9 +335,9 @@ func upstreamLDAPRefresh(ctx context.Context, providerCache oidc.UpstreamIdentit
|
||||
AdditionalAttributes: additionalAttributes,
|
||||
})
|
||||
if err != nil {
|
||||
return errorsx.WithStack(errUpstreamRefreshError.WithHint(
|
||||
"Upstream refresh failed.").WithWrap(err).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
return errUpstreamRefreshError().WithHint(
|
||||
"Upstream refresh failed.").WithTrace(err).
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
|
||||
}
|
||||
// Replace the old value with the new value.
|
||||
session.Fosite.Claims.Extra[oidc.DownstreamGroupsClaim] = groups
|
||||
@ -362,7 +364,7 @@ func findLDAPProviderByNameAndValidateUID(
|
||||
for _, p := range providers {
|
||||
if p.GetName() == s.ProviderName {
|
||||
if p.GetResourceUID() != s.ProviderUID {
|
||||
return nil, "", errorsx.WithStack(errUpstreamRefreshError.WithHint(
|
||||
return nil, "", errorsx.WithStack(errUpstreamRefreshError().WithHint(
|
||||
"Provider from upstream session data has changed its resource UID since authentication.").
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
@ -370,7 +372,7 @@ func findLDAPProviderByNameAndValidateUID(
|
||||
}
|
||||
}
|
||||
|
||||
return nil, "", errorsx.WithStack(errUpstreamRefreshError.
|
||||
return nil, "", errorsx.WithStack(errUpstreamRefreshError().
|
||||
WithHint("Provider from upstream session data was not found.").
|
||||
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))
|
||||
}
|
||||
@ -378,15 +380,15 @@ func findLDAPProviderByNameAndValidateUID(
|
||||
func getDownstreamUsernameFromPinnipedSession(session *psession.PinnipedSession) (string, error) {
|
||||
extra := session.Fosite.Claims.Extra
|
||||
if extra == nil {
|
||||
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError())
|
||||
}
|
||||
downstreamUsernameInterface := extra[oidc.DownstreamUsernameClaim]
|
||||
if downstreamUsernameInterface == nil {
|
||||
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError())
|
||||
}
|
||||
downstreamUsername, ok := downstreamUsernameInterface.(string)
|
||||
if !ok || len(downstreamUsername) == 0 {
|
||||
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
return "", errorsx.WithStack(errMissingUpstreamSessionInternalError())
|
||||
}
|
||||
return downstreamUsername, nil
|
||||
}
|
||||
@ -394,22 +396,22 @@ func getDownstreamUsernameFromPinnipedSession(session *psession.PinnipedSession)
|
||||
func getDownstreamGroupsFromPinnipedSession(session *psession.PinnipedSession) ([]string, error) {
|
||||
extra := session.Fosite.Claims.Extra
|
||||
if extra == nil {
|
||||
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError())
|
||||
}
|
||||
downstreamGroupsInterface := extra[oidc.DownstreamGroupsClaim]
|
||||
if downstreamGroupsInterface == nil {
|
||||
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError())
|
||||
}
|
||||
downstreamGroupsInterfaceList, ok := downstreamGroupsInterface.([]interface{})
|
||||
if !ok {
|
||||
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError())
|
||||
}
|
||||
|
||||
downstreamGroups := make([]string, 0, len(downstreamGroupsInterfaceList))
|
||||
for _, downstreamGroupInterface := range downstreamGroupsInterfaceList {
|
||||
downstreamGroup, ok := downstreamGroupInterface.(string)
|
||||
if !ok || len(downstreamGroup) == 0 {
|
||||
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError)
|
||||
return nil, errorsx.WithStack(errMissingUpstreamSessionInternalError())
|
||||
}
|
||||
downstreamGroups = append(downstreamGroups, downstreamGroup)
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user