Check that username and subject remain the same for ldap refresh

This commit is contained in:
Margo Crawford 2021-10-25 14:25:43 -07:00
parent 19281313dd
commit 7a58086040
8 changed files with 246 additions and 35 deletions

View File

@ -487,10 +487,7 @@ func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken
func downstreamSubjectFromUpstreamLDAP(ldapUpstream provider.UpstreamLDAPIdentityProviderI, authenticateResponse *authenticator.Response) string { func downstreamSubjectFromUpstreamLDAP(ldapUpstream provider.UpstreamLDAPIdentityProviderI, authenticateResponse *authenticator.Response) string {
ldapURL := *ldapUpstream.GetURL() ldapURL := *ldapUpstream.GetURL()
q := ldapURL.Query() return downstreamsession.DownstreamLDAPSubject(authenticateResponse.User.GetUID(), ldapURL)
q.Set(oidc.IDTokenSubjectClaim, authenticateResponse.User.GetUID())
ldapURL.RawQuery = q.Encode()
return ldapURL.String()
} }
func userDNFromAuthenticatedResponse(authenticatedResponse *authenticator.Response) string { func userDNFromAuthenticatedResponse(authenticatedResponse *authenticator.Response) string {

View File

@ -169,6 +169,13 @@ func extractStringClaimValue(claimName string, upstreamIDPName string, idTokenCl
return valueAsString, nil return valueAsString, nil
} }
func DownstreamLDAPSubject(uid string, ldapURL url.URL) string {
q := ldapURL.Query()
q.Set(oidc.IDTokenSubjectClaim, uid)
ldapURL.RawQuery = q.Encode()
return ldapURL.String()
}
func downstreamSubjectFromUpstreamOIDC(upstreamIssuerAsString string, upstreamSubject string) string { func downstreamSubjectFromUpstreamOIDC(upstreamIssuerAsString string, upstreamSubject string) string {
return fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, oidc.IDTokenSubjectClaim, url.QueryEscape(upstreamSubject)) return fmt.Sprintf("%s?%s=%s", upstreamIssuerAsString, oidc.IDTokenSubjectClaim, url.QueryEscape(upstreamSubject))
} }

View File

@ -90,7 +90,7 @@ type UpstreamLDAPIdentityProviderI interface {
authenticators.UserAuthenticator authenticators.UserAuthenticator
// PerformRefresh performs a refresh against the upstream LDAP identity provider // PerformRefresh performs a refresh against the upstream LDAP identity provider
PerformRefresh(ctx context.Context, userDN string) error PerformRefresh(ctx context.Context, userDN string, expectedUsername string, expectedSubject string) error
} }
type DynamicUpstreamIDPProvider interface { type DynamicUpstreamIDPProvider interface {

View File

@ -6,6 +6,7 @@ package token
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"github.com/ory/fosite" "github.com/ory/fosite"
@ -75,6 +76,18 @@ func NewHandler(
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error { func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error {
session := accessRequest.GetSession().(*psession.PinnipedSession) session := accessRequest.GetSession().(*psession.PinnipedSession)
fositeSession := session.Fosite
if fositeSession == nil {
return fmt.Errorf("fosite session not found")
}
claims := fositeSession.Claims
if claims == nil {
return fmt.Errorf("fosite session claims not found")
}
extra := claims.Extra
downstreamUsername := extra["username"].(string)
downstreamSubject := claims.Subject
customSessionData := session.Custom customSessionData := session.Custom
if customSessionData == nil { if customSessionData == nil {
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return errorsx.WithStack(errMissingUpstreamSessionInternalError)
@ -89,9 +102,9 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
case psession.ProviderTypeOIDC: case psession.ProviderTypeOIDC:
return upstreamOIDCRefresh(ctx, customSessionData, providerCache) return upstreamOIDCRefresh(ctx, customSessionData, providerCache)
case psession.ProviderTypeLDAP: case psession.ProviderTypeLDAP:
return upstreamLDAPRefresh(ctx, customSessionData, providerCache) return upstreamLDAPRefresh(ctx, customSessionData, providerCache, downstreamUsername, downstreamSubject)
case psession.ProviderTypeActiveDirectory: case psession.ProviderTypeActiveDirectory:
return upstreamLDAPRefresh(ctx, customSessionData, providerCache) return upstreamLDAPRefresh(ctx, customSessionData, providerCache, downstreamUsername, downstreamSubject)
default: default:
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return errorsx.WithStack(errMissingUpstreamSessionInternalError)
} }
@ -164,19 +177,19 @@ func findOIDCProviderByNameAndValidateUID(
WithHintf("Provider %q of type %q from upstream session data was not found.", s.ProviderName, s.ProviderType)) WithHintf("Provider %q of type %q from upstream session data was not found.", s.ProviderName, s.ProviderType))
} }
func upstreamLDAPRefresh(ctx context.Context, s *psession.CustomSessionData, providerCache oidc.UpstreamIdentityProvidersLister) error { func upstreamLDAPRefresh(ctx context.Context, s *psession.CustomSessionData, providerCache oidc.UpstreamIdentityProvidersLister, username string, subject string) error {
plog.Warning("refreshing upstream")
// if you have neither a valid ldap session config nor a valid active directory session config // if you have neither a valid ldap session config nor a valid active directory session config
if (s.LDAP == nil || s.LDAP.UserDN == "") && (s.ActiveDirectory == nil || s.ActiveDirectory.UserDN == "") { if (s.LDAP == nil || s.LDAP.UserDN == "") && (s.ActiveDirectory == nil || s.ActiveDirectory.UserDN == "") {
return errorsx.WithStack(errMissingUpstreamSessionInternalError) return errorsx.WithStack(errMissingUpstreamSessionInternalError)
} }
plog.Warning("going to find provider", "provider", s.ProviderName)
// get ldap/ad provider out of cache // get ldap/ad provider out of cache
p, dn, _ := findLDAPProviderByNameAndValidateUID(s, providerCache) p, dn, err := findLDAPProviderByNameAndValidateUID(s, providerCache)
// TODO error checking if err != nil {
return err
}
// run PerformRefresh // run PerformRefresh
err := p.PerformRefresh(ctx, dn) err = p.PerformRefresh(ctx, dn, username, subject)
if err != nil { if err != nil {
return errorsx.WithStack(errUpstreamRefreshError.WithHintf( return errorsx.WithStack(errUpstreamRefreshError.WithHintf(
"Upstream refresh failed using provider %q of type %q.", "Upstream refresh failed using provider %q of type %q.",

View File

@ -935,6 +935,8 @@ func TestRefreshGrant(t *testing.T) {
args: &oidctestutil.PerformRefreshArgs{ args: &oidctestutil.PerformRefreshArgs{
Ctx: nil, Ctx: nil,
DN: ldapUpstreamDN, DN: ldapUpstreamDN,
ExpectedSubject: goodSubject,
ExpectedUsername: goodUsername,
}, },
} }
} }
@ -945,6 +947,8 @@ func TestRefreshGrant(t *testing.T) {
args: &oidctestutil.PerformRefreshArgs{ args: &oidctestutil.PerformRefreshArgs{
Ctx: nil, Ctx: nil,
DN: activeDirectoryUpstreamDN, DN: activeDirectoryUpstreamDN,
ExpectedSubject: goodSubject,
ExpectedUsername: goodUsername,
}, },
} }
} }
@ -1796,6 +1800,7 @@ func TestRefreshGrant(t *testing.T) {
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyLDAPUpstreamRefreshCall(),
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(` wantErrorResponseBody: here.Doc(`
{ {
@ -1837,6 +1842,7 @@ func TestRefreshGrant(t *testing.T) {
}, },
refreshRequest: refreshRequestInputs{ refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{ want: tokenEndpointResponseExpectedValues{
wantUpstreamRefreshCall: happyActiveDirectoryUpstreamRefreshCall(),
wantStatus: http.StatusUnauthorized, wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(` wantErrorResponseBody: here.Doc(`
{ {
@ -1847,6 +1853,78 @@ func TestRefreshGrant(t *testing.T) {
}, },
}, },
}, },
{
name: "upstream ldap idp not found",
idps: oidctestutil.NewUpstreamIDPListerBuilder(),
authcodeExchange: authcodeExchangeInputs{
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
customSessionData: &psession.CustomSessionData{
ProviderUID: ldapUpstreamResourceUID,
ProviderName: ldapUpstreamName,
ProviderType: ldapUpstreamType,
LDAP: &psession.LDAPSessionData{
UserDN: ldapUpstreamDN,
},
},
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(
&psession.CustomSessionData{
ProviderUID: ldapUpstreamResourceUID,
ProviderName: ldapUpstreamName,
ProviderType: ldapUpstreamType,
LDAP: &psession.LDAPSessionData{
UserDN: ldapUpstreamDN,
},
},
),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
"error": "error",
"error_description": "Error during upstream refresh. Provider 'some-ldap-idp' of type 'ldap' from upstream session data was not found."
}
`),
},
},
},
{
name: "upstream active directory idp not found",
idps: oidctestutil.NewUpstreamIDPListerBuilder(),
authcodeExchange: authcodeExchangeInputs{
modifyAuthRequest: func(r *http.Request) { r.Form.Set("scope", "openid offline_access") },
customSessionData: &psession.CustomSessionData{
ProviderUID: activeDirectoryUpstreamResourceUID,
ProviderName: activeDirectoryUpstreamName,
ProviderType: activeDirectoryUpstreamType,
ActiveDirectory: &psession.ActiveDirectorySessionData{
UserDN: activeDirectoryUpstreamDN,
},
},
want: happyAuthcodeExchangeTokenResponseForOpenIDAndOfflineAccess(
&psession.CustomSessionData{
ProviderUID: activeDirectoryUpstreamResourceUID,
ProviderName: activeDirectoryUpstreamName,
ProviderType: activeDirectoryUpstreamType,
ActiveDirectory: &psession.ActiveDirectorySessionData{
UserDN: activeDirectoryUpstreamDN,
},
},
),
},
refreshRequest: refreshRequestInputs{
want: tokenEndpointResponseExpectedValues{
wantStatus: http.StatusUnauthorized,
wantErrorResponseBody: here.Doc(`
{
"error": "error",
"error_description": "Error during upstream refresh. Provider 'some-ad-idp' of type 'activedirectory' from upstream session data was not found."
}
`),
},
},
},
} }
for _, test := range tests { for _, test := range tests {
test := test test := test

View File

@ -64,6 +64,8 @@ type PerformRefreshArgs struct {
Ctx context.Context Ctx context.Context
RefreshToken string RefreshToken string
DN string DN string
ExpectedUsername string
ExpectedSubject string
} }
// ValidateTokenArgs is used to spy on calls to // ValidateTokenArgs is used to spy on calls to
@ -102,7 +104,7 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL {
return u.URL return u.URL
} }
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, userDN string) error { func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, userDN string, expectedUsername string, expectedSubject string) error {
if u.performRefreshArgs == nil { if u.performRefreshArgs == nil {
u.performRefreshArgs = make([]*PerformRefreshArgs, 0) u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
} }
@ -110,6 +112,8 @@ func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, u
u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{ u.performRefreshArgs = append(u.performRefreshArgs, &PerformRefreshArgs{
Ctx: ctx, Ctx: ctx,
DN: userDN, DN: userDN,
ExpectedUsername: expectedUsername,
ExpectedSubject: expectedSubject,
}) })
if u.PerformRefreshErr != nil { if u.PerformRefreshErr != nil {
return u.PerformRefreshErr return u.PerformRefreshErr

View File

@ -27,6 +27,7 @@ import (
"go.pinniped.dev/internal/authenticators" "go.pinniped.dev/internal/authenticators"
"go.pinniped.dev/internal/endpointaddr" "go.pinniped.dev/internal/endpointaddr"
"go.pinniped.dev/internal/oidc/downstreamsession"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
) )
@ -169,7 +170,7 @@ func (p *Provider) GetConfig() ProviderConfig {
return p.c return p.c
} }
func (p *Provider) PerformRefresh(ctx context.Context, userDN string) error { func (p *Provider) PerformRefresh(ctx context.Context, userDN string, expectedUsername string, expectedSubject string) error {
t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()}) t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()})
defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches
search := p.refreshUserSearchRequest(userDN) search := p.refreshUserSearchRequest(userDN)
@ -201,6 +202,30 @@ func (p *Provider) PerformRefresh(ctx context.Context, userDN string) error {
) )
} }
userEntry := searchResult.Entries[0]
if len(userEntry.DN) == 0 {
return fmt.Errorf(`searching for user with original DN "%s" resulted in search result without DN`, userDN)
}
newUsername, err := p.getSearchResultAttributeValue(p.c.UserSearch.UsernameAttribute, userEntry, userDN)
if err != nil {
return err // TODO test having no values or more than one maybe
}
if newUsername != expectedUsername {
return fmt.Errorf(`searching for user "%s" returned a different username than the previous value. expected: "%s", actual: "%s"`,
userDN, expectedUsername, newUsername,
)
}
newUID, err := p.getSearchResultAttributeRawValueEncoded(p.c.UserSearch.UIDAttribute, userEntry, userDN)
if err != nil {
return err // TODO test
}
newSubject := downstreamsession.DownstreamLDAPSubject(newUID, *p.GetURL())
if newSubject != expectedSubject {
return fmt.Errorf(`searching for user "%s" produced a different subject than the previous value. expected: "%s", actual: "%s"`, userDN, expectedSubject, newSubject)
}
// do nothing. if we got exactly one search result back then that means the user // do nothing. if we got exactly one search result back then that means the user
// still exists. // still exists.
return nil return nil
@ -616,7 +641,7 @@ func (p *Provider) refreshUserSearchRequest(dn string) *ldap.SearchRequest {
TimeLimit: 90, TimeLimit: 90,
TypesOnly: false, TypesOnly: false,
Filter: "(objectClass=*)", // we already have the dn, so the filter doesn't matter Filter: "(objectClass=*)", // we already have the dn, so the filter doesn't matter
Attributes: []string{}, // TODO this will need to include some other AD attributes Attributes: p.userSearchRequestedAttributes(), // TODO this will need to include some other AD attributes
Controls: nil, // this could be used to enable paging, but we're already limiting the result max size Controls: nil, // this could be used to enable paging, but we're already limiting the result max size
} }
} }

View File

@ -1223,7 +1223,7 @@ func TestUpstreamRefresh(t *testing.T) {
TimeLimit: 90, TimeLimit: 90,
TypesOnly: false, TypesOnly: false,
Filter: "(objectClass=*)", Filter: "(objectClass=*)",
Attributes: []string{}, Attributes: []string{testUserSearchUsernameAttribute, testUserSearchUIDAttribute},
Controls: nil, // don't need paging because we set the SizeLimit so small Controls: nil, // don't need paging because we set the SizeLimit so small
} }
@ -1231,7 +1231,17 @@ func TestUpstreamRefresh(t *testing.T) {
Entries: []*ldap.Entry{ Entries: []*ldap.Entry{
{ {
DN: testUserSearchResultDNValue, DN: testUserSearchResultDNValue,
Attributes: []*ldap.EntryAttribute{}, Attributes: []*ldap.EntryAttribute{
{
Name: testUserSearchUsernameAttribute,
Values: []string{testUserSearchResultUsernameAttributeValue},
},
{
Name: testUserSearchUIDAttribute,
Values: []string{testUserSearchResultUIDAttributeValue},
ByteValues: [][]byte{[]byte(testUserSearchResultUIDAttributeValue)},
},
},
}, },
}, },
} }
@ -1245,6 +1255,8 @@ func TestUpstreamRefresh(t *testing.T) {
BindPassword: testBindPassword, BindPassword: testBindPassword,
UserSearch: UserSearchConfig{ UserSearch: UserSearchConfig{
Base: testUserSearchBase, Base: testUserSearchBase,
UIDAttribute: testUserSearchUIDAttribute,
UsernameAttribute: testUserSearchUsernameAttribute,
}, },
} }
@ -1322,6 +1334,80 @@ func TestUpstreamRefresh(t *testing.T) {
}, },
wantErr: "searching for user \"some-upstream-user-dn\" resulted in 2 search results, but expected 1 result", wantErr: "searching for user \"some-upstream-user-dn\" resulted in 2 search results, but expected 1 result",
}, },
{
name: "search result has wrong uid",
providerConfig: providerConfig,
setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{
Entries: []*ldap.Entry{
{
DN: testUserSearchResultDNValue,
Attributes: []*ldap.EntryAttribute{
{
Name: testUserSearchUsernameAttribute,
Values: []string{testUserSearchResultUsernameAttributeValue},
},
{
Name: testUserSearchUIDAttribute,
Values: []string{"wrong-uid"},
ByteValues: [][]byte{[]byte("wrong-uid")},
},
},
},
},
}, nil).Times(1)
conn.EXPECT().Close().Times(1)
},
wantErr: "searching for user \"some-upstream-user-dn\" produced a different subject than the previous value. expected: \"ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU\", actual: \"ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=d3JvbmctdWlk\"",
},
{
name: "search result has wrong username",
providerConfig: providerConfig,
setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{
Entries: []*ldap.Entry{
{
DN: testUserSearchResultDNValue,
Attributes: []*ldap.EntryAttribute{
{
Name: testUserSearchUsernameAttribute,
Values: []string{"wrong-username"},
},
},
},
},
}, nil).Times(1)
conn.EXPECT().Close().Times(1)
},
wantErr: "searching for user \"some-upstream-user-dn\" returned a different username than the previous value. expected: \"some-upstream-username-value\", actual: \"wrong-username\"",
},
{
name: "search result has no dn",
providerConfig: providerConfig,
setupMocks: func(conn *mockldapconn.MockConn) {
conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1)
conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{
Entries: []*ldap.Entry{
{
Attributes: []*ldap.EntryAttribute{
{
Name: testUserSearchUsernameAttribute,
Values: []string{testUserSearchResultUsernameAttributeValue},
},
{
Name: testUserSearchUIDAttribute,
Values: []string{testUserSearchResultUIDAttributeValue},
},
},
},
},
}, nil).Times(1)
conn.EXPECT().Close().Times(1)
},
wantErr: "searching for user with original DN \"some-upstream-user-dn\" resulted in search result without DN",
},
} }
for _, test := range tests { for _, test := range tests {
@ -1347,9 +1433,10 @@ func TestUpstreamRefresh(t *testing.T) {
}) })
provider := New(*providerConfig) provider := New(*providerConfig)
err := provider.PerformRefresh(context.Background(), testUserSearchResultDNValue) subject := "ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU"
err := provider.PerformRefresh(context.Background(), testUserSearchResultDNValue, testUserSearchResultUsernameAttributeValue, subject)
if tt.wantErr != "" { if tt.wantErr != "" {
require.NotNil(t, err) require.Error(t, err)
require.Equal(t, tt.wantErr, err.Error()) require.Equal(t, tt.wantErr, err.Error())
} else { } else {
require.NoError(t, err) require.NoError(t, err)