First draft of implementation of multiple IDPs support

This commit is contained in:
Ryan Richard 2023-05-08 14:07:38 -07:00
parent 1a53b4daea
commit 7af75dfe3c
44 changed files with 1465 additions and 626 deletions

View File

@ -1,6 +1,6 @@
#!/usr/bin/env bash #!/usr/bin/env bash
# Copyright 2021-2022 the Pinniped contributors. All Rights Reserved. # Copyright 2021-2023 the Pinniped contributors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# #
@ -81,7 +81,7 @@ while (("$#")); do
done done
if [[ "$use_oidc_upstream" == "no" && "$use_ldap_upstream" == "no" && "$use_ad_upstream" == "no" ]]; then if [[ "$use_oidc_upstream" == "no" && "$use_ldap_upstream" == "no" && "$use_ad_upstream" == "no" ]]; then
log_error "Error: Please use --oidc, --ldap, or --ad to specify which type of upstream identity provider(s) you would like" log_error "Error: Please use --oidc, --ldap, or --ad to specify which type(s) of upstream identity provider(s) you would like. May use one or multiple."
exit 1 exit 1
fi fi
@ -103,42 +103,6 @@ audience="my-workload-cluster-$(openssl rand -hex 4)"
issuer_host="pinniped-supervisor-clusterip.supervisor.svc.cluster.local" issuer_host="pinniped-supervisor-clusterip.supervisor.svc.cluster.local"
issuer="https://$issuer_host/some/path" issuer="https://$issuer_host/some/path"
# Create a CA and TLS serving certificates for the Supervisor.
step certificate create \
"Supervisor CA" "$root_ca_crt_path" "$root_ca_key_path" \
--profile root-ca \
--no-password --insecure --force
step certificate create \
"$issuer_host" "$tls_crt_path" "$tls_key_path" \
--profile leaf \
--not-after 8760h \
--ca "$root_ca_crt_path" --ca-key "$root_ca_key_path" \
--no-password --insecure --force
# Put the TLS certificate into a Secret for the Supervisor.
kubectl create secret tls -n "$PINNIPED_TEST_SUPERVISOR_NAMESPACE" my-federation-domain-tls --cert "$tls_crt_path" --key "$tls_key_path" \
--dry-run=client --output yaml | kubectl apply -f -
# Make a FederationDomain using the TLS Secret from above.
cat <<EOF | kubectl apply --namespace "$PINNIPED_TEST_SUPERVISOR_NAMESPACE" -f -
apiVersion: config.supervisor.pinniped.dev/v1alpha1
kind: FederationDomain
metadata:
name: my-federation-domain
spec:
issuer: $issuer
tls:
secretName: my-federation-domain-tls
EOF
echo "Waiting for FederationDomain to initialize..."
# Sleeping is a race, but that's probably good enough for the purposes of this script.
sleep 5
# Test that the federation domain is working before we proceed.
echo "Fetching FederationDomain discovery info..."
https_proxy="$PINNIPED_TEST_PROXY" curl -fLsS --cacert "$root_ca_crt_path" "$issuer/.well-known/openid-configuration" | jq .
if [[ "$use_oidc_upstream" == "yes" ]]; then if [[ "$use_oidc_upstream" == "yes" ]]; then
# Make an OIDCIdentityProvider which uses Dex to provide identity. # Make an OIDCIdentityProvider which uses Dex to provide identity.
cat <<EOF | kubectl apply --namespace "$PINNIPED_TEST_SUPERVISOR_NAMESPACE" -f - cat <<EOF | kubectl apply --namespace "$PINNIPED_TEST_SUPERVISOR_NAMESPACE" -f -
@ -254,6 +218,146 @@ EOF
--dry-run=client --output yaml | kubectl apply -f - --dry-run=client --output yaml | kubectl apply -f -
fi fi
# Create a CA and TLS serving certificates for the Supervisor's FederationDomain.
if [[ ! -f "$root_ca_crt_path" ]]; then
step certificate create \
"Supervisor CA" "$root_ca_crt_path" "$root_ca_key_path" \
--profile root-ca \
--no-password --insecure --force
fi
if [[ ! -f "$tls_crt_path" || ! -f "$tls_key_path" ]]; then
step certificate create \
"$issuer_host" "$tls_crt_path" "$tls_key_path" \
--profile leaf \
--not-after 8760h \
--ca "$root_ca_crt_path" --ca-key "$root_ca_key_path" \
--no-password --insecure --force
fi
# Put the TLS certificate into a Secret for the Supervisor's FederationDomain.
kubectl create secret tls -n "$PINNIPED_TEST_SUPERVISOR_NAMESPACE" my-federation-domain-tls --cert "$tls_crt_path" --key "$tls_key_path" \
--dry-run=client --output yaml | kubectl apply -f -
# Variable that will be used to build up the "identityProviders" yaml for the FederationDomain.
fd_idps=""
if [[ "$use_oidc_upstream" == "yes" ]]; then
# Indenting the heredoc by 4 spaces to make it indented the correct amount in the FederationDomain below.
fd_idps="${fd_idps}$(
cat <<EOF
- displayName: "My OIDC IDP"
objectRef:
apiGroup: idp.supervisor.pinniped.dev
kind: OIDCIdentityProvider
name: my-oidc-provider
transforms:
expressions:
- type: username/v1
expression: '"oidc:" + username'
- type: groups/v1 # the pinny user doesn't belong to any groups in Dex, so this isn't strictly needed, but doesn't hurt
expression: 'groups.map(group, "oidc:" + group)'
examples:
- username: ryan@example.com
groups: [ a, b ]
expects:
username: oidc:ryan@example.com
groups: [ oidc:a, oidc:b ]
EOF
)"
fi
if [[ "$use_ldap_upstream" == "yes" ]]; then
# Indenting the heredoc by 4 spaces to make it indented the correct amount in the FederationDomain below.
fd_idps="${fd_idps}$(
cat <<EOF
- displayName: "My LDAP IDP"
objectRef:
apiGroup: idp.supervisor.pinniped.dev
kind: LDAPIdentityProvider
name: my-ldap-provider
transforms: # these are contrived to exercise all the available features
constants:
- name: prefix
type: string
stringValue: "ldap:"
- name: onlyIncludeGroupsWithThisPrefix
type: string
stringValue: "ball-" # pinny belongs to ball-game-players in openldap
- name: mustBelongToOneOfThese
type: stringList
stringListValue: [ ball-admins, seals ] # pinny belongs to seals in openldap
- name: additionalAdmins
type: stringList
stringListValue: [ pinny.ldap@example.com, ryan@example.com ] # pinny's email address in openldap
expressions:
- type: policy/v1
expression: 'groups.exists(g, g in strListConst.mustBelongToOneOfThese)'
message: "Only users in certain kube groups are allowed to authenticate"
- type: groups/v1
expression: 'username in strListConst.additionalAdmins ? groups + ["ball-admins"] : groups'
- type: groups/v1
expression: 'groups.filter(group, group.startsWith(strConst.onlyIncludeGroupsWithThisPrefix))'
- type: username/v1
expression: 'strConst.prefix + username'
- type: groups/v1
expression: 'groups.map(group, strConst.prefix + group)'
examples:
- username: ryan@example.com
groups: [ ball-developers, seals, non-ball-group ] # allowed to auth because belongs to seals
expects:
username: ldap:ryan@example.com
groups: [ ldap:ball-developers, ldap:ball-admins ] # gets ball-admins because of username, others dropped because they lack "ball-" prefix
- username: someone_else@example.com
groups: [ ball-developers, ball-admins, non-ball-group ] # allowed to auth because belongs to ball-admins
expects:
username: ldap:someone_else@example.com
groups: [ ldap:ball-developers, ldap:ball-admins ] # seals dropped because it lacks prefix
- username: paul@example.com
groups: [ not-ball-admins-group, not-seals-group ] # reject because does not belong to any of the required groups
expects:
rejected: true
message: "Only users in certain kube groups are allowed to authenticate"
EOF
)"
fi
if [[ "$use_ad_upstream" == "yes" ]]; then
# Indenting the heredoc by 4 spaces to make it indented the correct amount in the FederationDomain below.
fd_idps="${fd_idps}$(
cat <<EOF
- displayName: "My AD IDP"
objectRef:
apiGroup: idp.supervisor.pinniped.dev
kind: ActiveDirectoryIdentityProvider
name: my-ad-provider
EOF
)"
fi
# Make a FederationDomain using the TLS Secret and identity providers from above.
cat <<EOF | kubectl apply --namespace "$PINNIPED_TEST_SUPERVISOR_NAMESPACE" -f -
apiVersion: config.supervisor.pinniped.dev/v1alpha1
kind: FederationDomain
metadata:
name: my-federation-domain
spec:
issuer: $issuer
tls:
secretName: my-federation-domain-tls
identityProviders:${fd_idps}
EOF
echo "Waiting for FederationDomain to initialize or update..."
# Sleeping is a race, but that's probably good enough for the purposes of this script.
sleep 5
# Test that the federation domain is working before we proceed.
echo "Fetching FederationDomain discovery info via command: https_proxy=\"$PINNIPED_TEST_PROXY\" curl -fLsS --cacert \"$root_ca_crt_path\" \"$issuer/.well-known/openid-configuration\""
https_proxy="$PINNIPED_TEST_PROXY" curl -fLsS --cacert "$root_ca_crt_path" "$issuer/.well-known/openid-configuration" | jq .
if [[ "$OSTYPE" == "darwin"* ]]; then if [[ "$OSTYPE" == "darwin"* ]]; then
certificateAuthorityData=$(cat "$root_ca_crt_path" | base64) certificateAuthorityData=$(cat "$root_ca_crt_path" | base64)
else else
@ -275,7 +379,7 @@ spec:
certificateAuthorityData: $certificateAuthorityData certificateAuthorityData: $certificateAuthorityData
EOF EOF
echo "Waiting for JWTAuthenticator to initialize..." echo "Waiting for JWTAuthenticator to initialize or update..."
# Sleeping is a race, but that's probably good enough for the purposes of this script. # Sleeping is a race, but that's probably good enough for the purposes of this script.
sleep 5 sleep 5
@ -288,12 +392,24 @@ while [[ -z "$(kubectl get credentialissuer pinniped-concierge-config -o=jsonpat
sleep 2 sleep 2
done done
# Use the CLI to get the kubeconfig. Tell it that you don't want the browser to automatically open for logins. # Use the CLI to get the kubeconfig. Tell it that you don't want the browser to automatically open for browser-based
# flows so we can open our own browser with the proxy settings. Generate a kubeconfig for each IDP.
flow_arg="" flow_arg=""
if [[ -n "$use_flow" ]]; then if [[ -n "$use_flow" ]]; then
flow_arg="--upstream-identity-provider-flow $use_flow" flow_arg="--upstream-identity-provider-flow $use_flow"
fi fi
https_proxy="$PINNIPED_TEST_PROXY" no_proxy="127.0.0.1" ./pinniped get kubeconfig --oidc-skip-browser $flow_arg >kubeconfig if [[ "$use_oidc_upstream" == "yes" ]]; then
https_proxy="$PINNIPED_TEST_PROXY" no_proxy="127.0.0.1" \
./pinniped get kubeconfig --oidc-skip-browser $flow_arg --upstream-identity-provider-type oidc >kubeconfig-oidc.yaml
fi
if [[ "$use_ldap_upstream" == "yes" ]]; then
https_proxy="$PINNIPED_TEST_PROXY" no_proxy="127.0.0.1" \
./pinniped get kubeconfig --oidc-skip-browser $flow_arg --upstream-identity-provider-type ldap >kubeconfig-ldap.yaml
fi
if [[ "$use_ad_upstream" == "yes" ]]; then
https_proxy="$PINNIPED_TEST_PROXY" no_proxy="127.0.0.1" \
./pinniped get kubeconfig --oidc-skip-browser $flow_arg --upstream-identity-provider-type activedirectory >kubeconfig-ad.yaml
fi
# Clear the local CLI cache to ensure that the kubectl command below will need to perform a fresh login. # Clear the local CLI cache to ensure that the kubectl command below will need to perform a fresh login.
rm -f "$HOME/.config/pinniped/sessions.yaml" rm -f "$HOME/.config/pinniped/sessions.yaml"
@ -304,37 +420,48 @@ echo "Ready! 🚀"
if [[ "$use_oidc_upstream" == "yes" || "$use_flow" == "browser_authcode" ]]; then if [[ "$use_oidc_upstream" == "yes" || "$use_flow" == "browser_authcode" ]]; then
echo echo
echo "To be able to access the login URL shown below, start Chrome like this:" echo "To be able to access the Supervisor URL during login, start Chrome like this:"
echo " open -a \"Google Chrome\" --args --proxy-server=\"$PINNIPED_TEST_PROXY\"" echo " open -a \"Google Chrome\" --args --proxy-server=\"$PINNIPED_TEST_PROXY\""
echo "Note that Chrome must be fully quit before being started with --proxy-server." echo "Note that Chrome must be fully quit before being started with --proxy-server."
echo "Then open the login URL shown below in that new Chrome window." echo "Then open the login URL shown below in that new Chrome window."
echo echo
echo "When prompted for username and password, use these values:" echo "When prompted for username and password, use these values:"
echo
fi fi
if [[ "$use_oidc_upstream" == "yes" ]]; then if [[ "$use_oidc_upstream" == "yes" ]]; then
echo " Username: $PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_USERNAME" echo " OIDC Username: $PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_USERNAME"
echo " Password: $PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_PASSWORD" echo " OIDC Password: $PINNIPED_TEST_SUPERVISOR_UPSTREAM_OIDC_PASSWORD"
echo
fi fi
if [[ "$use_ldap_upstream" == "yes" ]]; then if [[ "$use_ldap_upstream" == "yes" ]]; then
echo " Username: $PINNIPED_TEST_LDAP_USER_CN" echo " LDAP Username: $PINNIPED_TEST_LDAP_USER_CN"
echo " Password: $PINNIPED_TEST_LDAP_USER_PASSWORD" echo " LDAP Password: $PINNIPED_TEST_LDAP_USER_PASSWORD"
echo
fi fi
if [[ "$use_ad_upstream" == "yes" ]]; then if [[ "$use_ad_upstream" == "yes" ]]; then
echo " Username: $PINNIPED_TEST_AD_USER_USER_PRINCIPAL_NAME" echo " AD Username: $PINNIPED_TEST_AD_USER_USER_PRINCIPAL_NAME"
echo " Password: $PINNIPED_TEST_AD_USER_PASSWORD" echo " AD Password: $PINNIPED_TEST_AD_USER_PASSWORD"
echo
fi fi
# Perform a login using the kubectl plugin. This should print the URL to be followed for the Dex login page # Echo the commands that may be used to login and print the identity of the currently logged in user.
# if using an OIDC upstream, or should prompt on the CLI for username/password if using an LDAP upstream. # Once the CLI has cached your tokens, it will automatically refresh your short-lived credentials whenever
echo # they expire, so you should not be prompted to log in again for the rest of the day.
echo "Running: PINNIPED_DEBUG=true https_proxy=\"$PINNIPED_TEST_PROXY\" no_proxy=\"127.0.0.1\" kubectl --kubeconfig ./kubeconfig get pods -A" if [[ "$use_oidc_upstream" == "yes" ]]; then
PINNIPED_DEBUG=true https_proxy="$PINNIPED_TEST_PROXY" no_proxy="127.0.0.1" kubectl --kubeconfig ./kubeconfig get pods -A echo "To log in using OIDC, run:"
echo "PINNIPED_DEBUG=true https_proxy=\"$PINNIPED_TEST_PROXY\" no_proxy=\"127.0.0.1\" ./pinniped whoami --kubeconfig ./kubeconfig-oidc.yaml"
# Print the identity of the currently logged in user. The CLI has cached your tokens, and will automatically refresh echo
# your short-lived credentials whenever they expire, so you should not be prompted to log in again for the rest of the day. fi
echo if [[ "$use_ldap_upstream" == "yes" ]]; then
echo "Running: PINNIPED_DEBUG=true https_proxy=\"$PINNIPED_TEST_PROXY\" no_proxy=\"127.0.0.1\" ./pinniped whoami --kubeconfig ./kubeconfig" echo "To log in using LDAP, run:"
PINNIPED_DEBUG=true https_proxy="$PINNIPED_TEST_PROXY" no_proxy="127.0.0.1" ./pinniped whoami --kubeconfig ./kubeconfig echo "PINNIPED_DEBUG=true https_proxy=\"$PINNIPED_TEST_PROXY\" no_proxy=\"127.0.0.1\" ./pinniped whoami --kubeconfig ./kubeconfig-ldap.yaml"
echo
fi
if [[ "$use_ad_upstream" == "yes" ]]; then
echo "To log in using AD, run:"
echo "PINNIPED_DEBUG=true https_proxy=\"$PINNIPED_TEST_PROXY\" no_proxy=\"127.0.0.1\" ./pinniped whoami --kubeconfig ./kubeconfig-ad.yaml"
echo
fi

View File

@ -26,7 +26,7 @@ import (
"go.pinniped.dev/internal/controller/conditionsutil" "go.pinniped.dev/internal/controller/conditionsutil"
"go.pinniped.dev/internal/controller/supervisorconfig/upstreamwatchers" "go.pinniped.dev/internal/controller/supervisorconfig/upstreamwatchers"
"go.pinniped.dev/internal/controllerlib" "go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/upstreamldap" "go.pinniped.dev/internal/upstreamldap"
) )
@ -225,7 +225,7 @@ func (s *activeDirectoryUpstreamGenericLDAPStatus) Conditions() []metav1.Conditi
// UpstreamActiveDirectoryIdentityProviderICache is a thread safe cache that holds a list of validated upstream LDAP IDP configurations. // UpstreamActiveDirectoryIdentityProviderICache is a thread safe cache that holds a list of validated upstream LDAP IDP configurations.
type UpstreamActiveDirectoryIdentityProviderICache interface { type UpstreamActiveDirectoryIdentityProviderICache interface {
SetActiveDirectoryIdentityProviders([]provider.UpstreamLDAPIdentityProviderI) SetActiveDirectoryIdentityProviders([]upstreamprovider.UpstreamLDAPIdentityProviderI)
} }
type activeDirectoryWatcherController struct { type activeDirectoryWatcherController struct {
@ -299,7 +299,7 @@ func (c *activeDirectoryWatcherController) Sync(ctx controllerlib.Context) error
} }
requeue := false requeue := false
validatedUpstreams := make([]provider.UpstreamLDAPIdentityProviderI, 0, len(actualUpstreams)) validatedUpstreams := make([]upstreamprovider.UpstreamLDAPIdentityProviderI, 0, len(actualUpstreams))
for _, upstream := range actualUpstreams { for _, upstream := range actualUpstreams {
valid, requestedRequeue := c.validateUpstream(ctx.Context, upstream) valid, requestedRequeue := c.validateUpstream(ctx.Context, upstream)
if valid != nil { if valid != nil {
@ -318,7 +318,7 @@ func (c *activeDirectoryWatcherController) Sync(ctx controllerlib.Context) error
return nil return nil
} }
func (c *activeDirectoryWatcherController) validateUpstream(ctx context.Context, upstream *v1alpha1.ActiveDirectoryIdentityProvider) (p provider.UpstreamLDAPIdentityProviderI, requeue bool) { func (c *activeDirectoryWatcherController) validateUpstream(ctx context.Context, upstream *v1alpha1.ActiveDirectoryIdentityProvider) (p upstreamprovider.UpstreamLDAPIdentityProviderI, requeue bool) {
spec := upstream.Spec spec := upstream.Spec
adUpstreamImpl := &activeDirectoryUpstreamGenericLDAPImpl{activeDirectoryIdentityProvider: *upstream} adUpstreamImpl := &activeDirectoryUpstreamGenericLDAPImpl{activeDirectoryIdentityProvider: *upstream}
@ -344,7 +344,7 @@ func (c *activeDirectoryWatcherController) validateUpstream(ctx context.Context,
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){ UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){
"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID"), "objectGUID": microsoftUUIDFromBinaryAttr("objectGUID"),
}, },
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
pwdLastSetAttribute: attributeUnchangedSinceLogin(pwdLastSetAttribute), pwdLastSetAttribute: attributeUnchangedSinceLogin(pwdLastSetAttribute),
userAccountControlAttribute: validUserAccountControl, userAccountControlAttribute: validUserAccountControl,
userAccountControlComputedAttribute: validComputedUserAccountControl, userAccountControlComputedAttribute: validComputedUserAccountControl,
@ -445,7 +445,7 @@ func getDomainFromDistinguishedName(distinguishedName string) (string, error) {
} }
//nolint:gochecknoglobals // this needs to be a global variable so that tests can check pointer equality //nolint:gochecknoglobals // this needs to be a global variable so that tests can check pointer equality
var validUserAccountControl = func(entry *ldap.Entry, _ provider.RefreshAttributes) error { var validUserAccountControl = func(entry *ldap.Entry, _ upstreamprovider.RefreshAttributes) error {
userAccountControl, err := strconv.Atoi(entry.GetAttributeValue(userAccountControlAttribute)) userAccountControl, err := strconv.Atoi(entry.GetAttributeValue(userAccountControlAttribute))
if err != nil { if err != nil {
return err return err
@ -459,7 +459,7 @@ var validUserAccountControl = func(entry *ldap.Entry, _ provider.RefreshAttribut
} }
//nolint:gochecknoglobals // this needs to be a global variable so that tests can check pointer equality //nolint:gochecknoglobals // this needs to be a global variable so that tests can check pointer equality
var validComputedUserAccountControl = func(entry *ldap.Entry, _ provider.RefreshAttributes) error { var validComputedUserAccountControl = func(entry *ldap.Entry, _ upstreamprovider.RefreshAttributes) error {
userAccountControl, err := strconv.Atoi(entry.GetAttributeValue(userAccountControlComputedAttribute)) userAccountControl, err := strconv.Atoi(entry.GetAttributeValue(userAccountControlComputedAttribute))
if err != nil { if err != nil {
return err return err
@ -473,8 +473,8 @@ var validComputedUserAccountControl = func(entry *ldap.Entry, _ provider.Refresh
} }
//nolint:gochecknoglobals // this needs to be a global variable so that tests can check pointer equality //nolint:gochecknoglobals // this needs to be a global variable so that tests can check pointer equality
var attributeUnchangedSinceLogin = func(attribute string) func(*ldap.Entry, provider.RefreshAttributes) error { var attributeUnchangedSinceLogin = func(attribute string) func(*ldap.Entry, upstreamprovider.RefreshAttributes) error {
return func(entry *ldap.Entry, storedAttributes provider.RefreshAttributes) error { return func(entry *ldap.Entry, storedAttributes upstreamprovider.RefreshAttributes) error {
prevAttributeValue := storedAttributes.AdditionalAttributes[attribute] prevAttributeValue := storedAttributes.AdditionalAttributes[attribute]
newValues := entry.GetRawAttributeValues(attribute) newValues := entry.GetRawAttributeValues(attribute)

View File

@ -31,6 +31,7 @@ import (
"go.pinniped.dev/internal/endpointaddr" "go.pinniped.dev/internal/endpointaddr"
"go.pinniped.dev/internal/mocks/mockldapconn" "go.pinniped.dev/internal/mocks/mockldapconn"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/internal/upstreamldap" "go.pinniped.dev/internal/upstreamldap"
) )
@ -229,7 +230,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
GroupNameAttribute: testGroupSearchNameAttrName, GroupNameAttribute: testGroupSearchNameAttrName,
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -572,7 +573,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
GroupNameAttribute: testGroupSearchNameAttrName, GroupNameAttribute: testGroupSearchNameAttrName,
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -642,7 +643,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
GroupNameAttribute: "sAMAccountName", GroupNameAttribute: "sAMAccountName",
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -715,7 +716,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
GroupNameAttribute: testGroupSearchNameAttrName, GroupNameAttribute: testGroupSearchNameAttrName,
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -795,7 +796,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
GroupNameAttribute: testGroupSearchNameAttrName, GroupNameAttribute: testGroupSearchNameAttrName,
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -859,7 +860,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
GroupNameAttribute: testGroupSearchNameAttrName, GroupNameAttribute: testGroupSearchNameAttrName,
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -1010,7 +1011,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
GroupNameAttribute: testGroupSearchNameAttrName, GroupNameAttribute: testGroupSearchNameAttrName,
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -1160,7 +1161,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
GroupNameAttribute: testGroupSearchNameAttrName, GroupNameAttribute: testGroupSearchNameAttrName,
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -1232,7 +1233,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
GroupNameAttribute: testGroupSearchNameAttrName, GroupNameAttribute: testGroupSearchNameAttrName,
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -1499,7 +1500,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
GroupAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"sAMAccountName": groupSAMAccountNameWithDomainSuffix}, GroupAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"sAMAccountName": groupSAMAccountNameWithDomainSuffix},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -1559,7 +1560,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
GroupNameAttribute: testGroupSearchNameAttrName, GroupNameAttribute: testGroupSearchNameAttrName,
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -1623,7 +1624,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
GroupNameAttribute: testGroupSearchNameAttrName, GroupNameAttribute: testGroupSearchNameAttrName,
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -1687,7 +1688,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
GroupNameAttribute: testGroupSearchNameAttrName, GroupNameAttribute: testGroupSearchNameAttrName,
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -1899,7 +1900,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
GroupNameAttribute: testGroupSearchNameAttrName, GroupNameAttribute: testGroupSearchNameAttrName,
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -1962,7 +1963,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
SkipGroupRefresh: true, SkipGroupRefresh: true,
}, },
UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")}, UIDAttributeParsingOverrides: map[string]func(*ldap.Entry) (string, error){"objectGUID": microsoftUUIDFromBinaryAttr("objectGUID")},
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
"pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"), "pwdLastSet": attributeUnchangedSinceLogin("pwdLastSet"),
"userAccountControl": validUserAccountControl, "userAccountControl": validUserAccountControl,
"msDS-User-Account-Control-Computed": validComputedUserAccountControl, "msDS-User-Account-Control-Computed": validComputedUserAccountControl,
@ -2010,7 +2011,7 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
fakeKubeClient := fake.NewSimpleClientset(tt.inputSecrets...) fakeKubeClient := fake.NewSimpleClientset(tt.inputSecrets...)
kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0) kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0)
cache := provider.NewDynamicUpstreamIDPProvider() cache := provider.NewDynamicUpstreamIDPProvider()
cache.SetActiveDirectoryIdentityProviders([]provider.UpstreamLDAPIdentityProviderI{ cache.SetActiveDirectoryIdentityProviders([]upstreamprovider.UpstreamLDAPIdentityProviderI{
upstreamldap.New(upstreamldap.ProviderConfig{Name: "initial-entry"}), upstreamldap.New(upstreamldap.ProviderConfig{Name: "initial-entry"}),
}) })
@ -2104,8 +2105,8 @@ func TestActiveDirectoryUpstreamWatcherControllerSync(t *testing.T) {
expectedRefreshAttributeChecks := copyOfExpectedValueForResultingCache.RefreshAttributeChecks expectedRefreshAttributeChecks := copyOfExpectedValueForResultingCache.RefreshAttributeChecks
actualRefreshAttributeChecks := actualConfig.RefreshAttributeChecks actualRefreshAttributeChecks := actualConfig.RefreshAttributeChecks
copyOfExpectedValueForResultingCache.RefreshAttributeChecks = map[string]func(*ldap.Entry, provider.RefreshAttributes) error{} copyOfExpectedValueForResultingCache.RefreshAttributeChecks = map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{}
actualConfig.RefreshAttributeChecks = map[string]func(*ldap.Entry, provider.RefreshAttributes) error{} actualConfig.RefreshAttributeChecks = map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{}
require.Equal(t, len(expectedRefreshAttributeChecks), len(actualRefreshAttributeChecks)) require.Equal(t, len(expectedRefreshAttributeChecks), len(actualRefreshAttributeChecks))
for k, v := range expectedRefreshAttributeChecks { for k, v := range expectedRefreshAttributeChecks {
require.NotNil(t, actualRefreshAttributeChecks[k]) require.NotNil(t, actualRefreshAttributeChecks[k])
@ -2354,7 +2355,7 @@ func TestValidUserAccountControl(t *testing.T) {
for _, test := range tests { for _, test := range tests {
tt := test tt := test
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
err := validUserAccountControl(tt.entry, provider.RefreshAttributes{}) err := validUserAccountControl(tt.entry, upstreamprovider.RefreshAttributes{})
if tt.wantErr != "" { if tt.wantErr != "" {
require.Error(t, err) require.Error(t, err)
@ -2415,7 +2416,7 @@ func TestValidComputedUserAccountControl(t *testing.T) {
for _, test := range tests { for _, test := range tests {
tt := test tt := test
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
err := validComputedUserAccountControl(tt.entry, provider.RefreshAttributes{}) err := validComputedUserAccountControl(tt.entry, upstreamprovider.RefreshAttributes{})
if tt.wantErr != "" { if tt.wantErr != "" {
require.Error(t, err) require.Error(t, err)

View File

@ -1,4 +1,4 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package supervisorconfig package supervisorconfig
@ -8,9 +8,11 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"strings" "strings"
"time"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/errors" "k8s.io/apimachinery/pkg/util/errors"
"k8s.io/client-go/util/retry" "k8s.io/client-go/util/retry"
"k8s.io/klog/v2" "k8s.io/klog/v2"
@ -19,8 +21,11 @@ import (
configv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/config/v1alpha1" configv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/config/v1alpha1"
pinnipedclientset "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned" pinnipedclientset "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned"
configinformers "go.pinniped.dev/generated/latest/client/supervisor/informers/externalversions/config/v1alpha1" configinformers "go.pinniped.dev/generated/latest/client/supervisor/informers/externalversions/config/v1alpha1"
idpinformers "go.pinniped.dev/generated/latest/client/supervisor/informers/externalversions/idp/v1alpha1"
"go.pinniped.dev/internal/celtransformer"
pinnipedcontroller "go.pinniped.dev/internal/controller" pinnipedcontroller "go.pinniped.dev/internal/controller"
"go.pinniped.dev/internal/controllerlib" "go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/idtransform"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
) )
@ -33,10 +38,14 @@ type ProvidersSetter interface {
} }
type federationDomainWatcherController struct { type federationDomainWatcherController struct {
providerSetter ProvidersSetter providerSetter ProvidersSetter
clock clock.Clock clock clock.Clock
client pinnipedclientset.Interface client pinnipedclientset.Interface
federationDomainInformer configinformers.FederationDomainInformer
federationDomainInformer configinformers.FederationDomainInformer
oidcIdentityProviderInformer idpinformers.OIDCIdentityProviderInformer
ldapIdentityProviderInformer idpinformers.LDAPIdentityProviderInformer
activeDirectoryIdentityProviderInformer idpinformers.ActiveDirectoryIdentityProviderInformer
} }
// NewFederationDomainWatcherController creates a controllerlib.Controller that watches // NewFederationDomainWatcherController creates a controllerlib.Controller that watches
@ -46,16 +55,22 @@ func NewFederationDomainWatcherController(
clock clock.Clock, clock clock.Clock,
client pinnipedclientset.Interface, client pinnipedclientset.Interface,
federationDomainInformer configinformers.FederationDomainInformer, federationDomainInformer configinformers.FederationDomainInformer,
oidcIdentityProviderInformer idpinformers.OIDCIdentityProviderInformer,
ldapIdentityProviderInformer idpinformers.LDAPIdentityProviderInformer,
activeDirectoryIdentityProviderInformer idpinformers.ActiveDirectoryIdentityProviderInformer,
withInformer pinnipedcontroller.WithInformerOptionFunc, withInformer pinnipedcontroller.WithInformerOptionFunc,
) controllerlib.Controller { ) controllerlib.Controller {
return controllerlib.New( return controllerlib.New(
controllerlib.Config{ controllerlib.Config{
Name: "FederationDomainWatcherController", Name: "FederationDomainWatcherController",
Syncer: &federationDomainWatcherController{ Syncer: &federationDomainWatcherController{
providerSetter: providerSetter, providerSetter: providerSetter,
clock: clock, clock: clock,
client: client, client: client,
federationDomainInformer: federationDomainInformer, federationDomainInformer: federationDomainInformer,
oidcIdentityProviderInformer: oidcIdentityProviderInformer,
ldapIdentityProviderInformer: ldapIdentityProviderInformer,
activeDirectoryIdentityProviderInformer: activeDirectoryIdentityProviderInformer,
}, },
}, },
withInformer( withInformer(
@ -63,6 +78,27 @@ func NewFederationDomainWatcherController(
pinnipedcontroller.MatchAnythingFilter(pinnipedcontroller.SingletonQueue()), pinnipedcontroller.MatchAnythingFilter(pinnipedcontroller.SingletonQueue()),
controllerlib.InformerOption{}, controllerlib.InformerOption{},
), ),
withInformer(
oidcIdentityProviderInformer,
// Since this controller only cares about IDP metadata names and UIDs (immutable fields),
// we only need to trigger Sync on creates and deletes.
pinnipedcontroller.MatchAnythingIgnoringUpdatesFilter(pinnipedcontroller.SingletonQueue()),
controllerlib.InformerOption{},
),
withInformer(
ldapIdentityProviderInformer,
// Since this controller only cares about IDP metadata names and UIDs (immutable fields),
// we only need to trigger Sync on creates and deletes.
pinnipedcontroller.MatchAnythingIgnoringUpdatesFilter(pinnipedcontroller.SingletonQueue()),
controllerlib.InformerOption{},
),
withInformer(
activeDirectoryIdentityProviderInformer,
// Since this controller only cares about IDP metadata names and UIDs (immutable fields),
// we only need to trigger Sync on creates and deletes.
pinnipedcontroller.MatchAnythingIgnoringUpdatesFilter(pinnipedcontroller.SingletonQueue()),
controllerlib.InformerOption{},
),
) )
} }
@ -143,8 +179,239 @@ func (c *federationDomainWatcherController) Sync(ctx controllerlib.Context) erro
continue continue
} }
federationDomainIssuer, err := provider.NewFederationDomainIssuer(federationDomain.Spec.Issuer) // This validates the Issuer URL. // TODO: Move all this identity provider stuff into helper functions. This is just a sketch of how the code would
// work in the sense that this is not doing error handling, is not validating everything that it should, and
// is not updating the status of the FederationDomain with anything related to these identity providers.
// This code may crash on invalid inputs since it is not handling any errors. However, when given valid inputs,
// this correctly implements the multiple IDPs features.
// Create the list of IDPs for this FederationDomain.
// Don't worry if the IDP CRs themselves is phase=Ready because those which are not ready will not be loaded
// into the provider cache, so they cannot actually be used to authenticate.
federationDomainIdentityProviders := []*provider.FederationDomainIdentityProvider{}
var defaultFederationDomainIdentityProvider *provider.FederationDomainIdentityProvider
if len(federationDomain.Spec.IdentityProviders) == 0 {
// When the FederationDomain does not list any IDPs, then we might be in backwards compatibility mode.
oidcIdentityProviders, _ := c.oidcIdentityProviderInformer.Lister().List(labels.Everything())
ldapIdentityProviders, _ := c.ldapIdentityProviderInformer.Lister().List(labels.Everything())
activeDirectoryIdentityProviders, _ := c.activeDirectoryIdentityProviderInformer.Lister().List(labels.Everything())
// TODO handle err return value for each of the above three lines
// Check if that there is exactly one IDP defined in the Supervisor namespace of any IDP CRD type.
idpCRsCount := len(oidcIdentityProviders) + len(ldapIdentityProviders) + len(activeDirectoryIdentityProviders)
if idpCRsCount == 1 {
// If so, default that IDP's DisplayName to be the same as its resource Name.
defaultFederationDomainIdentityProvider = &provider.FederationDomainIdentityProvider{}
switch {
case len(oidcIdentityProviders) == 1:
defaultFederationDomainIdentityProvider.DisplayName = oidcIdentityProviders[0].Name
defaultFederationDomainIdentityProvider.UID = oidcIdentityProviders[0].UID
case len(ldapIdentityProviders) == 1:
defaultFederationDomainIdentityProvider.DisplayName = ldapIdentityProviders[0].Name
defaultFederationDomainIdentityProvider.UID = ldapIdentityProviders[0].UID
case len(activeDirectoryIdentityProviders) == 1:
defaultFederationDomainIdentityProvider.DisplayName = activeDirectoryIdentityProviders[0].Name
defaultFederationDomainIdentityProvider.UID = activeDirectoryIdentityProviders[0].UID
}
// Backwards compatibility mode always uses an empty identity transformation pipline since no
// transformations are defined on the FederationDomain.
defaultFederationDomainIdentityProvider.Transforms = idtransform.NewTransformationPipeline()
plog.Warning("detected FederationDomain identity provider backwards compatibility mode: using the one existing identity provider for authentication",
"federationDomain", federationDomain.Name)
} else {
// There are no IDP CRs or there is more than one IDP CR. Either way, we are not in the backwards
// compatibility mode because there is not exactly one IDP CR in the namespace, despite the fact that no
// IDPs are listed on the FederationDomain. Create a FederationDomain which has no IDPs and therefore
// cannot actually be used to log in, but still serves endpoints.
// TODO: Write something into the FederationDomain's status to explain what's happening and how to fix it.
plog.Warning("FederationDomain has no identity providers listed and there is not exactly one identity provider defined in the namespace: authentication disabled",
"federationDomain", federationDomain.Name,
"namespace", federationDomain.Namespace,
"identityProvidersCustomResourcesCount", idpCRsCount,
)
}
}
// If there is an explicit list of IDPs on the FederationDomain, then process the list.
celTransformer, _ := celtransformer.NewCELTransformer(time.Second) // TODO: what is a good duration limit here?
// TODO: handle err
for _, idp := range federationDomain.Spec.IdentityProviders {
var idpResourceUID types.UID
var idpResourceName string
// TODO: Validate that all displayNames are unique within this FederationDomain's spec's list of identity providers.
// TODO: Validate that idp.ObjectRef.APIGroup is the expected APIGroup for IDP CRs "idp.supervisor.pinniped.dev"
// Validate that each objectRef resolves to an existing IDP. It does not matter if the IDP itself
// is phase=Ready, because it will not be loaded into the cache if not ready. For each objectRef
// that does not resolve, put an error on the FederationDomain status.
switch idp.ObjectRef.Kind {
case "LDAPIdentityProvider":
ldapIDP, _ := c.ldapIdentityProviderInformer.Lister().LDAPIdentityProviders(federationDomain.Namespace).Get(idp.ObjectRef.Name)
// TODO: handle notfound err and also unexpected errors
idpResourceName = ldapIDP.Name
idpResourceUID = ldapIDP.UID
case "ActiveDirectoryIdentityProvider":
adIDP, _ := c.activeDirectoryIdentityProviderInformer.Lister().ActiveDirectoryIdentityProviders(federationDomain.Namespace).Get(idp.ObjectRef.Name)
// TODO: handle notfound err and also unexpected errors
idpResourceName = adIDP.Name
idpResourceUID = adIDP.UID
case "OIDCIdentityProvider":
oidcIDP, _ := c.oidcIdentityProviderInformer.Lister().OIDCIdentityProviders(federationDomain.Namespace).Get(idp.ObjectRef.Name)
// TODO: handle notfound err and also unexpected errors
idpResourceName = oidcIDP.Name
idpResourceUID = oidcIDP.UID
default:
// TODO: handle bad user input
}
plog.Debug("resolved identity provider object reference",
"kind", idp.ObjectRef.Kind,
"name", idp.ObjectRef.Name,
"foundResourceName", idpResourceName,
"foundResourceUID", idpResourceUID,
)
// Prepare the transformations.
pipeline := idtransform.NewTransformationPipeline()
consts := &celtransformer.TransformationConstants{
StringConstants: map[string]string{},
StringListConstants: map[string][]string{},
}
// Read all the declared constants.
for _, c := range idp.Transforms.Constants {
switch c.Type {
case "string":
consts.StringConstants[c.Name] = c.StringValue
case "stringList":
consts.StringListConstants[c.Name] = c.StringListValue
default:
// TODO: this shouldn't really happen since the CRD validates it, but handle it as an error
}
}
// Compile all the expressions and add them to the pipeline.
for idx, e := range idp.Transforms.Expressions {
var rawTransform celtransformer.CELTransformation
switch e.Type {
case "username/v1":
rawTransform = &celtransformer.UsernameTransformation{Expression: e.Expression}
case "groups/v1":
rawTransform = &celtransformer.GroupsTransformation{Expression: e.Expression}
case "policy/v1":
rawTransform = &celtransformer.AllowAuthenticationPolicy{
Expression: e.Expression,
RejectedAuthenticationMessage: e.Message,
}
default:
// TODO: this shouldn't really happen since the CRD validates it, but handle it as an error
}
compiledTransform, err := celTransformer.CompileTransformation(rawTransform, consts)
if err != nil {
// TODO: handle compile err
plog.Error("error compiling identity transformation", err,
"federationDomain", federationDomain.Name,
"idpDisplayName", idp.DisplayName,
"transformationIndex", idx,
"transformationType", e.Type,
"transformationExpression", e.Expression,
)
}
pipeline.AppendTransformation(compiledTransform)
plog.Debug("successfully compiled identity transformation expression",
"type", e.Type,
"expr", e.Expression,
"policyMessage", e.Message,
)
}
// Run all the provided transform examples. If any fail, put errors on the FederationDomain status.
for idx, e := range idp.Transforms.Examples {
// TODO: use a real context param below
result, _ := pipeline.Evaluate(context.TODO(), e.Username, e.Groups)
// TODO: handle err
resultWasAuthRejected := !result.AuthenticationAllowed
if e.Expects.Rejected && !resultWasAuthRejected {
// TODO: handle this failed example
plog.Warning("FederationDomain identity provider transformations example failed: expected authentication to be rejected but it was not",
"federationDomain", federationDomain.Name,
"idpDisplayName", idp.DisplayName,
"exampleIndex", idx,
"expectedRejected", e.Expects.Rejected,
"actualRejectedResult", resultWasAuthRejected,
"expectedMessage", e.Expects.Message,
"actualMessageResult", result.RejectedAuthenticationMessage,
)
} else if !e.Expects.Rejected && resultWasAuthRejected {
// TODO: handle this failed example
plog.Warning("FederationDomain identity provider transformations example failed: expected authentication not to be rejected but it was rejected",
"federationDomain", federationDomain.Name,
"idpDisplayName", idp.DisplayName,
"exampleIndex", idx,
"expectedRejected", e.Expects.Rejected,
"actualRejectedResult", resultWasAuthRejected,
"expectedMessage", e.Expects.Message,
"actualMessageResult", result.RejectedAuthenticationMessage,
)
} else if e.Expects.Rejected && resultWasAuthRejected && e.Expects.Message != result.RejectedAuthenticationMessage {
// TODO: when expected message is blank, then treat it like it expects the default message
// TODO: handle this failed example
plog.Warning("FederationDomain identity provider transformations example failed: expected a different authentication rejection message",
"federationDomain", federationDomain.Name,
"idpDisplayName", idp.DisplayName,
"exampleIndex", idx,
"expectedRejected", e.Expects.Rejected,
"actualRejectedResult", resultWasAuthRejected,
"expectedMessage", e.Expects.Message,
"actualMessageResult", result.RejectedAuthenticationMessage,
)
} else if result.AuthenticationAllowed {
// In the case where the user expected the auth to be allowed and it was allowed, then compare
// the expected username and group names to the actual username and group names.
// TODO: when both of these fail, put both errors onto the status (not just the first one)
if e.Expects.Username != result.Username {
// TODO: handle this failed example
plog.Warning("FederationDomain identity provider transformations example failed: expected a different transformed username",
"federationDomain", federationDomain.Name,
"idpDisplayName", idp.DisplayName,
"exampleIndex", idx,
"expectedUsername", e.Expects.Username,
"actualUsernameResult", result.Username,
)
}
if !stringSlicesEqual(e.Expects.Groups, result.Groups) {
// TODO: Do we need to make this insensitive to ordering, or should the transformations evaluator be changed to always return sorted group names at the end of the pipeline?
// TODO: What happens if the user did not write any group expectation? Treat it like expecting any empty list of groups?
// TODO: handle this failed example
plog.Warning("FederationDomain identity provider transformations example failed: expected a different transformed groups list",
"federationDomain", federationDomain.Name,
"idpDisplayName", idp.DisplayName,
"exampleIndex", idx,
"expectedGroups", e.Expects.Groups,
"actualGroupsResult", result.Groups,
)
}
}
}
// For each valid IDP (unique displayName, valid objectRef + valid transforms), add it to the list.
federationDomainIdentityProviders = append(federationDomainIdentityProviders, &provider.FederationDomainIdentityProvider{
DisplayName: idp.DisplayName,
UID: idpResourceUID,
Transforms: pipeline,
})
plog.Debug("loaded FederationDomain identity provider",
"federationDomain", federationDomain.Name,
"identityProviderDisplayName", idp.DisplayName,
"identityProviderResourceUID", idpResourceUID,
)
}
// Now that we have the list of IDPs for this FederationDomain, create the issuer.
var federationDomainIssuer *provider.FederationDomainIssuer
err = nil
if defaultFederationDomainIdentityProvider != nil {
// This is the constructor for the backwards compatibility mode.
federationDomainIssuer, err = provider.NewFederationDomainIssuerWithDefaultIDP(federationDomain.Spec.Issuer, defaultFederationDomainIdentityProvider)
} else {
// This is the constructor for any other case, including when there is an empty list of IDPs.
federationDomainIssuer, err = provider.NewFederationDomainIssuer(federationDomain.Spec.Issuer, federationDomainIdentityProviders)
}
if err != nil { if err != nil {
// Note that the FederationDomainIssuer constructors validate the Issuer URL.
if err := c.updateStatus( if err := c.updateStatus(
ctx.Context, ctx.Context,
federationDomain.Namespace, federationDomain.Namespace,
@ -176,6 +443,18 @@ func (c *federationDomainWatcherController) Sync(ctx controllerlib.Context) erro
return errors.NewAggregate(errs) return errors.NewAggregate(errs)
} }
func stringSlicesEqual(a []string, b []string) bool {
if len(a) != len(b) {
return false
}
for i, itemFromA := range a {
if b[i] != itemFromA {
return false
}
}
return true
}
func (c *federationDomainWatcherController) updateStatus( func (c *federationDomainWatcherController) updateStatus(
ctx context.Context, ctx context.Context,
namespace, name string, namespace, name string,

View File

@ -20,7 +20,7 @@ import (
"go.pinniped.dev/internal/controller/conditionsutil" "go.pinniped.dev/internal/controller/conditionsutil"
"go.pinniped.dev/internal/controller/supervisorconfig/upstreamwatchers" "go.pinniped.dev/internal/controller/supervisorconfig/upstreamwatchers"
"go.pinniped.dev/internal/controllerlib" "go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/upstreamldap" "go.pinniped.dev/internal/upstreamldap"
) )
@ -133,7 +133,7 @@ func (s *ldapUpstreamGenericLDAPStatus) Conditions() []metav1.Condition {
// UpstreamLDAPIdentityProviderICache is a thread safe cache that holds a list of validated upstream LDAP IDP configurations. // UpstreamLDAPIdentityProviderICache is a thread safe cache that holds a list of validated upstream LDAP IDP configurations.
type UpstreamLDAPIdentityProviderICache interface { type UpstreamLDAPIdentityProviderICache interface {
SetLDAPIdentityProviders([]provider.UpstreamLDAPIdentityProviderI) SetLDAPIdentityProviders([]upstreamprovider.UpstreamLDAPIdentityProviderI)
} }
type ldapWatcherController struct { type ldapWatcherController struct {
@ -207,7 +207,7 @@ func (c *ldapWatcherController) Sync(ctx controllerlib.Context) error {
} }
requeue := false requeue := false
validatedUpstreams := make([]provider.UpstreamLDAPIdentityProviderI, 0, len(actualUpstreams)) validatedUpstreams := make([]upstreamprovider.UpstreamLDAPIdentityProviderI, 0, len(actualUpstreams))
for _, upstream := range actualUpstreams { for _, upstream := range actualUpstreams {
valid, requestedRequeue := c.validateUpstream(ctx.Context, upstream) valid, requestedRequeue := c.validateUpstream(ctx.Context, upstream)
if valid != nil { if valid != nil {
@ -226,7 +226,7 @@ func (c *ldapWatcherController) Sync(ctx controllerlib.Context) error {
return nil return nil
} }
func (c *ldapWatcherController) validateUpstream(ctx context.Context, upstream *v1alpha1.LDAPIdentityProvider) (p provider.UpstreamLDAPIdentityProviderI, requeue bool) { func (c *ldapWatcherController) validateUpstream(ctx context.Context, upstream *v1alpha1.LDAPIdentityProvider) (p upstreamprovider.UpstreamLDAPIdentityProviderI, requeue bool) {
spec := upstream.Spec spec := upstream.Spec
config := &upstreamldap.ProviderConfig{ config := &upstreamldap.ProviderConfig{

View File

@ -30,6 +30,7 @@ import (
"go.pinniped.dev/internal/endpointaddr" "go.pinniped.dev/internal/endpointaddr"
"go.pinniped.dev/internal/mocks/mockldapconn" "go.pinniped.dev/internal/mocks/mockldapconn"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/internal/upstreamldap" "go.pinniped.dev/internal/upstreamldap"
) )
@ -1139,7 +1140,7 @@ func TestLDAPUpstreamWatcherControllerSync(t *testing.T) {
fakeKubeClient := fake.NewSimpleClientset(tt.inputSecrets...) fakeKubeClient := fake.NewSimpleClientset(tt.inputSecrets...)
kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0) kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0)
cache := provider.NewDynamicUpstreamIDPProvider() cache := provider.NewDynamicUpstreamIDPProvider()
cache.SetLDAPIdentityProviders([]provider.UpstreamLDAPIdentityProviderI{ cache.SetLDAPIdentityProviders([]upstreamprovider.UpstreamLDAPIdentityProviderI{
upstreamldap.New(upstreamldap.ProviderConfig{Name: "initial-entry"}), upstreamldap.New(upstreamldap.ProviderConfig{Name: "initial-entry"}),
}) })

View File

@ -35,7 +35,7 @@ import (
"go.pinniped.dev/internal/controller/supervisorconfig/upstreamwatchers" "go.pinniped.dev/internal/controller/supervisorconfig/upstreamwatchers"
"go.pinniped.dev/internal/controllerlib" "go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/net/phttp" "go.pinniped.dev/internal/net/phttp"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/upstreamoidc" "go.pinniped.dev/internal/upstreamoidc"
) )
@ -91,7 +91,7 @@ var (
// UpstreamOIDCIdentityProviderICache is a thread safe cache that holds a list of validated upstream OIDC IDP configurations. // UpstreamOIDCIdentityProviderICache is a thread safe cache that holds a list of validated upstream OIDC IDP configurations.
type UpstreamOIDCIdentityProviderICache interface { type UpstreamOIDCIdentityProviderICache interface {
SetOIDCIdentityProviders([]provider.UpstreamOIDCIdentityProviderI) SetOIDCIdentityProviders([]upstreamprovider.UpstreamOIDCIdentityProviderI)
} }
// lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration. // lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration.
@ -175,13 +175,13 @@ func (c *oidcWatcherController) Sync(ctx controllerlib.Context) error {
} }
requeue := false requeue := false
validatedUpstreams := make([]provider.UpstreamOIDCIdentityProviderI, 0, len(actualUpstreams)) validatedUpstreams := make([]upstreamprovider.UpstreamOIDCIdentityProviderI, 0, len(actualUpstreams))
for _, upstream := range actualUpstreams { for _, upstream := range actualUpstreams {
valid := c.validateUpstream(ctx, upstream) valid := c.validateUpstream(ctx, upstream)
if valid == nil { if valid == nil {
requeue = true requeue = true
} else { } else {
validatedUpstreams = append(validatedUpstreams, provider.UpstreamOIDCIdentityProviderI(valid)) validatedUpstreams = append(validatedUpstreams, upstreamprovider.UpstreamOIDCIdentityProviderI(valid))
} }
} }
c.cache.SetOIDCIdentityProviders(validatedUpstreams) c.cache.SetOIDCIdentityProviders(validatedUpstreams)

View File

@ -29,6 +29,7 @@ import (
"go.pinniped.dev/internal/certauthority" "go.pinniped.dev/internal/certauthority"
"go.pinniped.dev/internal/controllerlib" "go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/internal/testutil/oidctestutil" "go.pinniped.dev/internal/testutil/oidctestutil"
@ -81,7 +82,7 @@ func TestOIDCUpstreamWatcherControllerFilterSecret(t *testing.T) {
fakeKubeClient := fake.NewSimpleClientset() fakeKubeClient := fake.NewSimpleClientset()
kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0) kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0)
cache := provider.NewDynamicUpstreamIDPProvider() cache := provider.NewDynamicUpstreamIDPProvider()
cache.SetOIDCIdentityProviders([]provider.UpstreamOIDCIdentityProviderI{ cache.SetOIDCIdentityProviders([]upstreamprovider.UpstreamOIDCIdentityProviderI{
&upstreamoidc.ProviderConfig{Name: "initial-entry"}, &upstreamoidc.ProviderConfig{Name: "initial-entry"},
}) })
secretInformer := kubeInformers.Core().V1().Secrets() secretInformer := kubeInformers.Core().V1().Secrets()
@ -1416,7 +1417,7 @@ oidc: issuer did not match the issuer returned by provider, expected "` + testIs
kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0) kubeInformers := informers.NewSharedInformerFactory(fakeKubeClient, 0)
testLog := testlogger.NewLegacy(t) //nolint:staticcheck // old test with lots of log statements testLog := testlogger.NewLegacy(t) //nolint:staticcheck // old test with lots of log statements
cache := provider.NewDynamicUpstreamIDPProvider() cache := provider.NewDynamicUpstreamIDPProvider()
cache.SetOIDCIdentityProviders([]provider.UpstreamOIDCIdentityProviderI{ cache.SetOIDCIdentityProviders([]upstreamprovider.UpstreamOIDCIdentityProviderI{
&upstreamoidc.ProviderConfig{Name: "initial-entry"}, &upstreamoidc.ProviderConfig{Name: "initial-entry"},
}) })

View File

@ -16,7 +16,7 @@ import (
"go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1" "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
"go.pinniped.dev/internal/constable" "go.pinniped.dev/internal/constable"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/upstreamldap" "go.pinniped.dev/internal/upstreamldap"
) )
@ -365,7 +365,7 @@ func validateAndSetLDAPServerConnectivityAndSearchBase(
return ldapConnectionValidCondition, searchBaseFoundCondition return ldapConnectionValidCondition, searchBaseFoundCondition
} }
func EvaluateConditions(conditions GradatedConditions, config *upstreamldap.ProviderConfig) (provider.UpstreamLDAPIdentityProviderI, bool) { func EvaluateConditions(conditions GradatedConditions, config *upstreamldap.ProviderConfig) (upstreamprovider.UpstreamLDAPIdentityProviderI, bool) {
for _, gradatedCondition := range conditions.gradatedConditions { for _, gradatedCondition := range conditions.gradatedConditions {
if gradatedCondition.condition.Status != metav1.ConditionTrue && gradatedCondition.isFatal { if gradatedCondition.condition.Status != metav1.ConditionTrue && gradatedCondition.isFatal {
// Invalid provider, so do not load it into the cache. // Invalid provider, so do not load it into the cache.

View File

@ -1,4 +1,4 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package supervisorstorage package supervisorstorage
@ -27,6 +27,7 @@ import (
"go.pinniped.dev/internal/fositestorage/pkce" "go.pinniped.dev/internal/fositestorage/pkce"
"go.pinniped.dev/internal/fositestorage/refreshtoken" "go.pinniped.dev/internal/fositestorage/refreshtoken"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/psession"
) )
@ -43,7 +44,7 @@ type garbageCollectorController struct {
// UpstreamOIDCIdentityProviderICache is a thread safe cache that holds a list of validated upstream OIDC IDP configurations. // UpstreamOIDCIdentityProviderICache is a thread safe cache that holds a list of validated upstream OIDC IDP configurations.
type UpstreamOIDCIdentityProviderICache interface { type UpstreamOIDCIdentityProviderICache interface {
GetOIDCIdentityProviders() []provider.UpstreamOIDCIdentityProviderI GetOIDCIdentityProviders() []upstreamprovider.UpstreamOIDCIdentityProviderI
} }
func GarbageCollectorController( func GarbageCollectorController(
@ -244,7 +245,7 @@ func (c *garbageCollectorController) tryRevokeUpstreamOIDCToken(ctx context.Cont
} }
// Try to find the provider that was originally used to create the stored session. // Try to find the provider that was originally used to create the stored session.
var foundOIDCIdentityProviderI provider.UpstreamOIDCIdentityProviderI var foundOIDCIdentityProviderI upstreamprovider.UpstreamOIDCIdentityProviderI
for _, p := range c.idpCache.GetOIDCIdentityProviders() { for _, p := range c.idpCache.GetOIDCIdentityProviders() {
if p.GetName() == customSessionData.ProviderName && p.GetResourceUID() == customSessionData.ProviderUID { if p.GetName() == customSessionData.ProviderName && p.GetResourceUID() == customSessionData.ProviderUID {
foundOIDCIdentityProviderI = p foundOIDCIdentityProviderI = p
@ -260,7 +261,7 @@ func (c *garbageCollectorController) tryRevokeUpstreamOIDCToken(ctx context.Cont
upstreamAccessToken := customSessionData.OIDC.UpstreamAccessToken upstreamAccessToken := customSessionData.OIDC.UpstreamAccessToken
if upstreamRefreshToken != "" { if upstreamRefreshToken != "" {
err := foundOIDCIdentityProviderI.RevokeToken(ctx, upstreamRefreshToken, provider.RefreshTokenType) err := foundOIDCIdentityProviderI.RevokeToken(ctx, upstreamRefreshToken, upstreamprovider.RefreshTokenType)
if err != nil { if err != nil {
return err return err
} }
@ -268,7 +269,7 @@ func (c *garbageCollectorController) tryRevokeUpstreamOIDCToken(ctx context.Cont
} }
if upstreamAccessToken != "" { if upstreamAccessToken != "" {
err := foundOIDCIdentityProviderI.RevokeToken(ctx, upstreamAccessToken, provider.AccessTokenType) err := foundOIDCIdentityProviderI.RevokeToken(ctx, upstreamAccessToken, upstreamprovider.AccessTokenType)
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,4 +1,4 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package supervisorstorage package supervisorstorage
@ -30,6 +30,7 @@ import (
"go.pinniped.dev/internal/fositestorage/refreshtoken" "go.pinniped.dev/internal/fositestorage/refreshtoken"
"go.pinniped.dev/internal/oidc/clientregistry" "go.pinniped.dev/internal/oidc/clientregistry"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/psession"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/internal/testutil/oidctestutil" "go.pinniped.dev/internal/testutil/oidctestutil"
@ -369,7 +370,7 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
&oidctestutil.RevokeTokenArgs{ &oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context, Ctx: syncContext.Context,
Token: "fake-upstream-refresh-token", Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType, TokenType: upstreamprovider.RefreshTokenType,
}, },
) )
@ -493,7 +494,7 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
&oidctestutil.RevokeTokenArgs{ &oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context, Ctx: syncContext.Context,
Token: "fake-upstream-access-token", Token: "fake-upstream-access-token",
TokenType: provider.AccessTokenType, TokenType: upstreamprovider.AccessTokenType,
}, },
) )
@ -785,7 +786,7 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
&oidctestutil.RevokeTokenArgs{ &oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context, Ctx: syncContext.Context,
Token: "fake-upstream-refresh-token", Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType, TokenType: upstreamprovider.RefreshTokenType,
}, },
) )
@ -810,7 +811,7 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
&oidctestutil.RevokeTokenArgs{ &oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context, Ctx: syncContext.Context,
Token: "fake-upstream-refresh-token", Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType, TokenType: upstreamprovider.RefreshTokenType,
}, },
) )
@ -889,7 +890,7 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
&oidctestutil.RevokeTokenArgs{ &oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context, Ctx: syncContext.Context,
Token: "fake-upstream-refresh-token", Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType, TokenType: upstreamprovider.RefreshTokenType,
}, },
) )
@ -1012,7 +1013,7 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
&oidctestutil.RevokeTokenArgs{ &oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context, Ctx: syncContext.Context,
Token: "fake-upstream-refresh-token", Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType, TokenType: upstreamprovider.RefreshTokenType,
}, },
) )
@ -1136,7 +1137,7 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
&oidctestutil.RevokeTokenArgs{ &oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context, Ctx: syncContext.Context,
Token: "fake-upstream-access-token", Token: "fake-upstream-access-token",
TokenType: provider.AccessTokenType, TokenType: upstreamprovider.AccessTokenType,
}, },
) )
@ -1214,7 +1215,7 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
&oidctestutil.RevokeTokenArgs{ &oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context, Ctx: syncContext.Context,
Token: "fake-upstream-refresh-token", Token: "fake-upstream-refresh-token",
TokenType: provider.RefreshTokenType, TokenType: upstreamprovider.RefreshTokenType,
}, },
) )
@ -1291,7 +1292,7 @@ func TestGarbageCollectorControllerSync(t *testing.T) {
&oidctestutil.RevokeTokenArgs{ &oidctestutil.RevokeTokenArgs{
Ctx: syncContext.Context, Ctx: syncContext.Context,
Token: "fake-upstream-access-token", Token: "fake-upstream-access-token",
TokenType: provider.AccessTokenType, TokenType: upstreamprovider.AccessTokenType,
}, },
) )

View File

@ -1,4 +1,4 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package controller package controller
@ -18,6 +18,16 @@ func NameAndNamespaceExactMatchFilterFactory(name, namespace string) controllerl
}, nil) }, nil)
} }
// MatchAnythingIgnoringUpdatesFilter returns a controllerlib.Filter that allows all objects but ignores updates.
func MatchAnythingIgnoringUpdatesFilter(parentFunc controllerlib.ParentFunc) controllerlib.Filter {
return controllerlib.FilterFuncs{
AddFunc: func(object metav1.Object) bool { return true },
UpdateFunc: func(oldObj, newObj metav1.Object) bool { return false },
DeleteFunc: func(object metav1.Object) bool { return true },
ParentFunc: parentFunc,
}
}
// MatchAnythingFilter returns a controllerlib.Filter that allows all objects. // MatchAnythingFilter returns a controllerlib.Filter that allows all objects.
func MatchAnythingFilter(parentFunc controllerlib.ParentFunc) controllerlib.Filter { func MatchAnythingFilter(parentFunc controllerlib.ParentFunc) controllerlib.Filter {
return SimpleFilter(func(object metav1.Object) bool { return true }, parentFunc) return SimpleFilter(func(object metav1.Object) bool { return true }, parentFunc)

View File

@ -1,4 +1,4 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package accesstoken package accesstoken
@ -31,7 +31,8 @@ const (
// Version 2 is when we switched to storing psession.PinnipedSession inside the fosite request. // Version 2 is when we switched to storing psession.PinnipedSession inside the fosite request.
// Version 3 is when we added the Username field to the psession.CustomSessionData. // Version 3 is when we added the Username field to the psession.CustomSessionData.
// Version 4 is when fosite added json tags to their openid.DefaultSession struct. // Version 4 is when fosite added json tags to their openid.DefaultSession struct.
accessTokenStorageVersion = "4" // Version 5 is when we added the UpstreamUsername and UpstreamGroups fields to psession.CustomSessionData.
accessTokenStorageVersion = "5"
) )
type RevocationStorage interface { type RevocationStorage interface {

View File

@ -1,4 +1,4 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package authorizationcode package authorizationcode
@ -32,7 +32,8 @@ const (
// Version 2 is when we switched to storing psession.PinnipedSession inside the fosite request. // Version 2 is when we switched to storing psession.PinnipedSession inside the fosite request.
// Version 3 is when we added the Username field to the psession.CustomSessionData. // Version 3 is when we added the Username field to the psession.CustomSessionData.
// Version 4 is when fosite added json tags to their openid.DefaultSession struct. // Version 4 is when fosite added json tags to their openid.DefaultSession struct.
authorizeCodeStorageVersion = "4" // Version 5 is when we added the UpstreamUsername and UpstreamGroups fields to psession.CustomSessionData.
authorizeCodeStorageVersion = "5"
) )
var _ oauth2.AuthorizeCodeStorage = &authorizeCodeStorage{} var _ oauth2.AuthorizeCodeStorage = &authorizeCodeStorage{}

View File

@ -1,4 +1,4 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package openidconnect package openidconnect
@ -32,7 +32,8 @@ const (
// Version 2 is when we switched to storing psession.PinnipedSession inside the fosite request. // Version 2 is when we switched to storing psession.PinnipedSession inside the fosite request.
// Version 3 is when we added the Username field to the psession.CustomSessionData. // Version 3 is when we added the Username field to the psession.CustomSessionData.
// Version 4 is when fosite added json tags to their openid.DefaultSession struct. // Version 4 is when fosite added json tags to their openid.DefaultSession struct.
oidcStorageVersion = "4" // Version 5 is when we added the UpstreamUsername and UpstreamGroups fields to psession.CustomSessionData.
oidcStorageVersion = "5"
) )
var _ openid.OpenIDConnectRequestStorage = &openIDConnectRequestStorage{} var _ openid.OpenIDConnectRequestStorage = &openIDConnectRequestStorage{}

View File

@ -1,4 +1,4 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package pkce package pkce
@ -30,7 +30,8 @@ const (
// Version 2 is when we switched to storing psession.PinnipedSession inside the fosite request. // Version 2 is when we switched to storing psession.PinnipedSession inside the fosite request.
// Version 3 is when we added the Username field to the psession.CustomSessionData. // Version 3 is when we added the Username field to the psession.CustomSessionData.
// Version 4 is when fosite added json tags to their openid.DefaultSession struct. // Version 4 is when fosite added json tags to their openid.DefaultSession struct.
pkceStorageVersion = "4" // Version 5 is when we added the UpstreamUsername and UpstreamGroups fields to psession.CustomSessionData.
pkceStorageVersion = "5"
) )
var _ pkce.PKCERequestStorage = &pkceStorage{} var _ pkce.PKCERequestStorage = &pkceStorage{}

View File

@ -1,4 +1,4 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package refreshtoken package refreshtoken
@ -31,7 +31,8 @@ const (
// Version 2 is when we switched to storing psession.PinnipedSession inside the fosite request. // Version 2 is when we switched to storing psession.PinnipedSession inside the fosite request.
// Version 3 is when we added the Username field to the psession.CustomSessionData. // Version 3 is when we added the Username field to the psession.CustomSessionData.
// Version 4 is when fosite added json tags to their openid.DefaultSession struct. // Version 4 is when fosite added json tags to their openid.DefaultSession struct.
refreshTokenStorageVersion = "4" // Version 5 is when we added the UpstreamUsername and UpstreamGroups fields to psession.CustomSessionData.
refreshTokenStorageVersion = "5"
) )
type RevocationStorage interface { type RevocationStorage interface {

View File

@ -1,6 +1,6 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package mockupstreamoidcidentityprovider package mockupstreamoidcidentityprovider
//go:generate go run -v github.com/golang/mock/mockgen -destination=mockupstreamoidcidentityprovider.go -package=mockupstreamoidcidentityprovider -copyright_file=../../../hack/header.txt go.pinniped.dev/internal/oidc/provider UpstreamOIDCIdentityProviderI //go:generate go run -v github.com/golang/mock/mockgen -destination=mockupstreamoidcidentityprovider.go -package=mockupstreamoidcidentityprovider -copyright_file=../../../hack/header.txt go.pinniped.dev/internal/oidc/provider/upstreamprovider UpstreamOIDCIdentityProviderI

View File

@ -3,7 +3,7 @@
// //
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: go.pinniped.dev/internal/oidc/provider (interfaces: UpstreamOIDCIdentityProviderI) // Source: go.pinniped.dev/internal/oidc/provider/upstreamprovider (interfaces: UpstreamOIDCIdentityProviderI)
// Package mockupstreamoidcidentityprovider is a generated GoMock package. // Package mockupstreamoidcidentityprovider is a generated GoMock package.
package mockupstreamoidcidentityprovider package mockupstreamoidcidentityprovider
@ -14,7 +14,7 @@ import (
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
provider "go.pinniped.dev/internal/oidc/provider" upstreamprovider "go.pinniped.dev/internal/oidc/provider/upstreamprovider"
nonce "go.pinniped.dev/pkg/oidcclient/nonce" nonce "go.pinniped.dev/pkg/oidcclient/nonce"
oidctypes "go.pinniped.dev/pkg/oidcclient/oidctypes" oidctypes "go.pinniped.dev/pkg/oidcclient/oidctypes"
pkce "go.pinniped.dev/pkg/oidcclient/pkce" pkce "go.pinniped.dev/pkg/oidcclient/pkce"
@ -245,7 +245,7 @@ func (mr *MockUpstreamOIDCIdentityProviderIMockRecorder) PerformRefresh(arg0, ar
} }
// RevokeToken mocks base method. // RevokeToken mocks base method.
func (m *MockUpstreamOIDCIdentityProviderI) RevokeToken(arg0 context.Context, arg1 string, arg2 provider.RevocableTokenType) error { func (m *MockUpstreamOIDCIdentityProviderI) RevokeToken(arg0 context.Context, arg1 string, arg2 upstreamprovider.RevocableTokenType) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RevokeToken", arg0, arg1, arg2) ret := m.ctrl.Call(m, "RevokeToken", arg0, arg1, arg2)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)

View File

@ -18,12 +18,14 @@ import (
oidcapi "go.pinniped.dev/generated/latest/apis/supervisor/oidc" oidcapi "go.pinniped.dev/generated/latest/apis/supervisor/oidc"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader" "go.pinniped.dev/internal/httputil/securityheader"
"go.pinniped.dev/internal/idtransform"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/downstreamsession" "go.pinniped.dev/internal/oidc/downstreamsession"
"go.pinniped.dev/internal/oidc/login" "go.pinniped.dev/internal/oidc/login"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/formposthtml" "go.pinniped.dev/internal/oidc/provider/formposthtml"
"go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/psession"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
@ -37,7 +39,7 @@ const (
func NewHandler( func NewHandler(
downstreamIssuer string, downstreamIssuer string,
idpLister oidc.UpstreamIdentityProvidersLister, idpFinder provider.FederationDomainIdentityProvidersFinderI,
oauthHelperWithoutStorage fosite.OAuth2Provider, oauthHelperWithoutStorage fosite.OAuth2Provider,
oauthHelperWithStorage fosite.OAuth2Provider, oauthHelperWithStorage fosite.OAuth2Provider,
generateCSRF func() (csrftoken.CSRFToken, error), generateCSRF func() (csrftoken.CSRFToken, error),
@ -57,20 +59,25 @@ func NewHandler(
// Note that the client might have used oidcapi.AuthorizeUpstreamIDPNameParamName and // Note that the client might have used oidcapi.AuthorizeUpstreamIDPNameParamName and
// oidcapi.AuthorizeUpstreamIDPTypeParamName query params to request a certain upstream IDP. // oidcapi.AuthorizeUpstreamIDPTypeParamName query params to request a certain upstream IDP.
// The Pinniped CLI has been sending these params since v0.9.0. // The Pinniped CLI has been sending these params since v0.9.0.
// Currently, these are ignored because the Supervisor does not yet support logins when multiple IDPs idpNameQueryParamValue := r.URL.Query().Get(oidcapi.AuthorizeUpstreamIDPNameParamName)
// are configured. However, these params should be honored in the future when choosing an upstream oidcUpstream, ldapUpstream, err := chooseUpstreamIDP(idpNameQueryParamValue, idpFinder)
// here, e.g. by calling oidcapi.FindUpstreamIDPByNameAndType() when the params are present.
oidcUpstream, ldapUpstream, idpType, err := chooseUpstreamIDP(idpLister)
if err != nil { if err != nil {
plog.WarningErr("authorize upstream config", err) plog.WarningErr("authorize upstream config", err)
return err return err
} }
if idpType == psession.ProviderTypeOIDC { if oidcUpstream != nil {
if len(r.Header.Values(oidcapi.AuthorizeUsernameHeaderName)) > 0 || if len(r.Header.Values(oidcapi.AuthorizeUsernameHeaderName)) > 0 ||
len(r.Header.Values(oidcapi.AuthorizePasswordHeaderName)) > 0 { len(r.Header.Values(oidcapi.AuthorizePasswordHeaderName)) > 0 {
// The client set a username header, so they are trying to log in with a username/password. // The client set a username header, so they are trying to log in with a username/password.
return handleAuthRequestForOIDCUpstreamPasswordGrant(r, w, oauthHelperWithStorage, oidcUpstream) return handleAuthRequestForOIDCUpstreamPasswordGrant(
r,
w,
oauthHelperWithStorage,
oidcUpstream.Provider,
oidcUpstream.Transforms,
idpNameQueryParamValue,
)
} }
return handleAuthRequestForOIDCUpstreamBrowserFlow(r, w, return handleAuthRequestForOIDCUpstreamBrowserFlow(r, w,
oauthHelperWithoutStorage, oauthHelperWithoutStorage,
@ -79,6 +86,7 @@ func NewHandler(
downstreamIssuer, downstreamIssuer,
upstreamStateEncoder, upstreamStateEncoder,
cookieCodec, cookieCodec,
idpNameQueryParamValue,
) )
} }
@ -88,8 +96,10 @@ func NewHandler(
// The client set a username header, so they are trying to log in with a username/password. // The client set a username header, so they are trying to log in with a username/password.
return handleAuthRequestForLDAPUpstreamCLIFlow(r, w, return handleAuthRequestForLDAPUpstreamCLIFlow(r, w,
oauthHelperWithStorage, oauthHelperWithStorage,
ldapUpstream, ldapUpstream.Provider,
idpType, ldapUpstream.SessionProviderType,
ldapUpstream.Transforms,
idpNameQueryParamValue,
) )
} }
return handleAuthRequestForLDAPUpstreamBrowserFlow( return handleAuthRequestForLDAPUpstreamBrowserFlow(
@ -100,10 +110,11 @@ func NewHandler(
generateNonce, generateNonce,
generatePKCE, generatePKCE,
ldapUpstream, ldapUpstream,
idpType, ldapUpstream.SessionProviderType,
downstreamIssuer, downstreamIssuer,
upstreamStateEncoder, upstreamStateEncoder,
cookieCodec, cookieCodec,
idpNameQueryParamValue,
) )
}) })
@ -117,24 +128,28 @@ func handleAuthRequestForLDAPUpstreamCLIFlow(
r *http.Request, r *http.Request,
w http.ResponseWriter, w http.ResponseWriter,
oauthHelper fosite.OAuth2Provider, oauthHelper fosite.OAuth2Provider,
ldapUpstream provider.UpstreamLDAPIdentityProviderI, ldapUpstream upstreamprovider.UpstreamLDAPIdentityProviderI,
idpType psession.ProviderType, idpType psession.ProviderType,
identityTransforms *idtransform.TransformationPipeline,
idpNameQueryParamValue string,
) error { ) error {
authorizeRequester, created := newAuthorizeRequest(r, w, oauthHelper, true) authorizeRequester, created := newAuthorizeRequest(r, w, oauthHelper, true)
if !created { if !created {
return nil return nil
} }
maybeLogDeprecationWarningForMissingIDPParam(idpNameQueryParamValue, authorizeRequester)
if !requireStaticClientForUsernameAndPasswordHeaders(r, w, oauthHelper, authorizeRequester) { if !requireStaticClientForUsernameAndPasswordHeaders(r, w, oauthHelper, authorizeRequester) {
return nil return nil
} }
username, password, hadUsernamePasswordValues := requireNonEmptyUsernameAndPasswordHeaders(r, w, oauthHelper, authorizeRequester) submittedUsername, submittedPassword, hadUsernamePasswordValues := requireNonEmptyUsernameAndPasswordHeaders(r, w, oauthHelper, authorizeRequester)
if !hadUsernamePasswordValues { if !hadUsernamePasswordValues {
return nil return nil
} }
authenticateResponse, authenticated, err := ldapUpstream.AuthenticateUser(r.Context(), username, password, authorizeRequester.GetGrantedScopes()) authenticateResponse, authenticated, err := ldapUpstream.AuthenticateUser(r.Context(), submittedUsername, submittedPassword, authorizeRequester.GetGrantedScopes())
if err != nil { if err != nil {
plog.WarningErr("unexpected error during upstream LDAP authentication", err, "upstreamName", ldapUpstream.GetName()) plog.WarningErr("unexpected error during upstream LDAP authentication", err, "upstreamName", ldapUpstream.GetName())
return httperr.New(http.StatusBadGateway, "unexpected error during upstream authentication") return httperr.New(http.StatusBadGateway, "unexpected error during upstream authentication")
@ -146,9 +161,18 @@ func handleAuthRequestForLDAPUpstreamCLIFlow(
} }
subject := downstreamsession.DownstreamSubjectFromUpstreamLDAP(ldapUpstream, authenticateResponse) subject := downstreamsession.DownstreamSubjectFromUpstreamLDAP(ldapUpstream, authenticateResponse)
username = authenticateResponse.User.GetName() upstreamUsername := authenticateResponse.User.GetName()
groups := authenticateResponse.User.GetGroups() upstreamGroups := authenticateResponse.User.GetGroups()
customSessionData := downstreamsession.MakeDownstreamLDAPOrADCustomSessionData(ldapUpstream, idpType, authenticateResponse, username)
username, groups, err := downstreamsession.ApplyIdentityTransformations(r.Context(), identityTransforms, upstreamUsername, upstreamGroups)
if err != nil {
oidc.WriteAuthorizeError(r, w, oauthHelper, authorizeRequester,
fosite.ErrAccessDenied.WithHintf("Reason: %s.", err.Error()), true,
)
return nil
}
customSessionData := downstreamsession.MakeDownstreamLDAPOrADCustomSessionData(ldapUpstream, idpType, authenticateResponse, username, upstreamUsername, upstreamGroups)
openIDSession := downstreamsession.MakeDownstreamSession(subject, username, groups, openIDSession := downstreamsession.MakeDownstreamSession(subject, username, groups,
authorizeRequester.GetGrantedScopes(), authorizeRequester.GetClient().GetID(), customSessionData, map[string]interface{}{}) authorizeRequester.GetGrantedScopes(), authorizeRequester.GetClient().GetID(), customSessionData, map[string]interface{}{})
oidc.PerformAuthcodeRedirect(r, w, oauthHelper, authorizeRequester, openIDSession, true) oidc.PerformAuthcodeRedirect(r, w, oauthHelper, authorizeRequester, openIDSession, true)
@ -163,11 +187,12 @@ func handleAuthRequestForLDAPUpstreamBrowserFlow(
generateCSRF func() (csrftoken.CSRFToken, error), generateCSRF func() (csrftoken.CSRFToken, error),
generateNonce func() (nonce.Nonce, error), generateNonce func() (nonce.Nonce, error),
generatePKCE func() (pkce.Code, error), generatePKCE func() (pkce.Code, error),
ldapUpstream provider.UpstreamLDAPIdentityProviderI, ldapUpstream *provider.FederationDomainResolvedLDAPIdentityProvider,
idpType psession.ProviderType, idpType psession.ProviderType,
downstreamIssuer string, downstreamIssuer string,
upstreamStateEncoder oidc.Encoder, upstreamStateEncoder oidc.Encoder,
cookieCodec oidc.Codec, cookieCodec oidc.Codec,
idpNameQueryParamValue string,
) error { ) error {
authRequestState, err := handleBrowserFlowAuthRequest( authRequestState, err := handleBrowserFlowAuthRequest(
r, r,
@ -176,10 +201,11 @@ func handleAuthRequestForLDAPUpstreamBrowserFlow(
generateCSRF, generateCSRF,
generateNonce, generateNonce,
generatePKCE, generatePKCE,
ldapUpstream.GetName(), ldapUpstream.DisplayName,
idpType, idpType,
cookieCodec, cookieCodec,
upstreamStateEncoder, upstreamStateEncoder,
idpNameQueryParamValue,
) )
if err != nil { if err != nil {
return err return err
@ -196,18 +222,22 @@ func handleAuthRequestForOIDCUpstreamPasswordGrant(
r *http.Request, r *http.Request,
w http.ResponseWriter, w http.ResponseWriter,
oauthHelper fosite.OAuth2Provider, oauthHelper fosite.OAuth2Provider,
oidcUpstream provider.UpstreamOIDCIdentityProviderI, oidcUpstream upstreamprovider.UpstreamOIDCIdentityProviderI,
identityTransforms *idtransform.TransformationPipeline,
idpNameQueryParamValue string,
) error { ) error {
authorizeRequester, created := newAuthorizeRequest(r, w, oauthHelper, true) authorizeRequester, created := newAuthorizeRequest(r, w, oauthHelper, true)
if !created { if !created {
return nil return nil
} }
maybeLogDeprecationWarningForMissingIDPParam(idpNameQueryParamValue, authorizeRequester)
if !requireStaticClientForUsernameAndPasswordHeaders(r, w, oauthHelper, authorizeRequester) { if !requireStaticClientForUsernameAndPasswordHeaders(r, w, oauthHelper, authorizeRequester) {
return nil return nil
} }
username, password, hadUsernamePasswordValues := requireNonEmptyUsernameAndPasswordHeaders(r, w, oauthHelper, authorizeRequester) submittedUsername, submittedPassword, hadUsernamePasswordValues := requireNonEmptyUsernameAndPasswordHeaders(r, w, oauthHelper, authorizeRequester)
if !hadUsernamePasswordValues { if !hadUsernamePasswordValues {
return nil return nil
} }
@ -220,7 +250,7 @@ func handleAuthRequestForOIDCUpstreamPasswordGrant(
return nil return nil
} }
token, err := oidcUpstream.PasswordCredentialsGrantAndValidateTokens(r.Context(), username, password) token, err := oidcUpstream.PasswordCredentialsGrantAndValidateTokens(r.Context(), submittedUsername, submittedPassword)
if err != nil { if err != nil {
// Upstream password grant errors can be generic errors (e.g. a network failure) or can be oauth2.RetrieveError errors // Upstream password grant errors can be generic errors (e.g. a network failure) or can be oauth2.RetrieveError errors
// which represent the http response from the upstream server. These could be a 5XX or some other unexpected error, // which represent the http response from the upstream server. These could be a 5XX or some other unexpected error,
@ -234,7 +264,7 @@ func handleAuthRequestForOIDCUpstreamPasswordGrant(
return nil return nil
} }
subject, username, groups, err := downstreamsession.GetDownstreamIdentityFromUpstreamIDToken(oidcUpstream, token.IDToken.Claims) subject, upstreamUsername, upstreamGroups, err := downstreamsession.GetDownstreamIdentityFromUpstreamIDToken(oidcUpstream, token.IDToken.Claims)
if err != nil { if err != nil {
// Return a user-friendly error for this case which is entirely within our control. // Return a user-friendly error for this case which is entirely within our control.
oidc.WriteAuthorizeError(r, w, oauthHelper, authorizeRequester, oidc.WriteAuthorizeError(r, w, oauthHelper, authorizeRequester,
@ -243,9 +273,17 @@ func handleAuthRequestForOIDCUpstreamPasswordGrant(
return nil return nil
} }
username, groups, err := downstreamsession.ApplyIdentityTransformations(r.Context(), identityTransforms, upstreamUsername, upstreamGroups)
if err != nil {
oidc.WriteAuthorizeError(r, w, oauthHelper, authorizeRequester,
fosite.ErrAccessDenied.WithHintf("Reason: %s.", err.Error()), true,
)
return nil
}
additionalClaims := downstreamsession.MapAdditionalClaimsFromUpstreamIDToken(oidcUpstream, token.IDToken.Claims) additionalClaims := downstreamsession.MapAdditionalClaimsFromUpstreamIDToken(oidcUpstream, token.IDToken.Claims)
customSessionData, err := downstreamsession.MakeDownstreamOIDCCustomSessionData(oidcUpstream, token, username) customSessionData, err := downstreamsession.MakeDownstreamOIDCCustomSessionData(oidcUpstream, token, username, upstreamUsername, upstreamGroups)
if err != nil { if err != nil {
oidc.WriteAuthorizeError(r, w, oauthHelper, authorizeRequester, oidc.WriteAuthorizeError(r, w, oauthHelper, authorizeRequester,
fosite.ErrAccessDenied.WithHintf("Reason: %s.", err.Error()), true, fosite.ErrAccessDenied.WithHintf("Reason: %s.", err.Error()), true,
@ -268,10 +306,11 @@ func handleAuthRequestForOIDCUpstreamBrowserFlow(
generateCSRF func() (csrftoken.CSRFToken, error), generateCSRF func() (csrftoken.CSRFToken, error),
generateNonce func() (nonce.Nonce, error), generateNonce func() (nonce.Nonce, error),
generatePKCE func() (pkce.Code, error), generatePKCE func() (pkce.Code, error),
oidcUpstream provider.UpstreamOIDCIdentityProviderI, oidcUpstream *provider.FederationDomainResolvedOIDCIdentityProvider,
downstreamIssuer string, downstreamIssuer string,
upstreamStateEncoder oidc.Encoder, upstreamStateEncoder oidc.Encoder,
cookieCodec oidc.Codec, cookieCodec oidc.Codec,
idpNameQueryParamValue string,
) error { ) error {
authRequestState, err := handleBrowserFlowAuthRequest( authRequestState, err := handleBrowserFlowAuthRequest(
r, r,
@ -280,10 +319,11 @@ func handleAuthRequestForOIDCUpstreamBrowserFlow(
generateCSRF, generateCSRF,
generateNonce, generateNonce,
generatePKCE, generatePKCE,
oidcUpstream.GetName(), oidcUpstream.DisplayName,
psession.ProviderTypeOIDC, psession.ProviderTypeOIDC,
cookieCodec, cookieCodec,
upstreamStateEncoder, upstreamStateEncoder,
idpNameQueryParamValue,
) )
if err != nil { if err != nil {
return err return err
@ -294,12 +334,12 @@ func handleAuthRequestForOIDCUpstreamBrowserFlow(
} }
upstreamOAuthConfig := oauth2.Config{ upstreamOAuthConfig := oauth2.Config{
ClientID: oidcUpstream.GetClientID(), ClientID: oidcUpstream.Provider.GetClientID(),
Endpoint: oauth2.Endpoint{ Endpoint: oauth2.Endpoint{
AuthURL: oidcUpstream.GetAuthorizationURL().String(), AuthURL: oidcUpstream.Provider.GetAuthorizationURL().String(),
}, },
RedirectURL: fmt.Sprintf("%s/callback", downstreamIssuer), RedirectURL: fmt.Sprintf("%s/callback", downstreamIssuer),
Scopes: oidcUpstream.GetScopes(), Scopes: oidcUpstream.Provider.GetScopes(),
} }
authCodeOptions := []oauth2.AuthCodeOption{ authCodeOptions := []oauth2.AuthCodeOption{
@ -308,7 +348,7 @@ func handleAuthRequestForOIDCUpstreamBrowserFlow(
authRequestState.pkce.Method(), authRequestState.pkce.Method(),
} }
for key, val := range oidcUpstream.GetAdditionalAuthcodeParams() { for key, val := range oidcUpstream.Provider.GetAdditionalAuthcodeParams() {
authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam(key, val)) authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam(key, val))
} }
@ -382,39 +422,31 @@ func readCSRFCookie(r *http.Request, codec oidc.Decoder) csrftoken.CSRFToken {
// chooseUpstreamIDP selects either an OIDC, an LDAP, or an AD IDP, or returns an error. // chooseUpstreamIDP selects either an OIDC, an LDAP, or an AD IDP, or returns an error.
// Note that AD and LDAP IDPs both return the same interface type, but different ProviderTypes values. // Note that AD and LDAP IDPs both return the same interface type, but different ProviderTypes values.
func chooseUpstreamIDP(idpLister oidc.UpstreamIdentityProvidersLister) (provider.UpstreamOIDCIdentityProviderI, provider.UpstreamLDAPIdentityProviderI, psession.ProviderType, error) { func chooseUpstreamIDP(idpDisplayName string, idpLister provider.FederationDomainIdentityProvidersFinderI) (*provider.FederationDomainResolvedOIDCIdentityProvider, *provider.FederationDomainResolvedLDAPIdentityProvider, error) {
oidcUpstreams := idpLister.GetOIDCIdentityProviders() // When a request is made to the authorization endpoint which does not specify the IDP name, then it might
ldapUpstreams := idpLister.GetLDAPIdentityProviders() // be an old dynamic client (OIDCClient). We need to make this work, but only in the backwards compatibility case
adUpstreams := idpLister.GetActiveDirectoryIdentityProviders() // where there is exactly one IDP defined in the namespace and no IDPs listed on the FederationDomain.
switch { // This backwards compatibility mode is handled by FindDefaultIDP().
case len(oidcUpstreams)+len(ldapUpstreams)+len(adUpstreams) == 0: if len(idpDisplayName) == 0 {
return nil, nil, "", httperr.New( return idpLister.FindDefaultIDP()
http.StatusUnprocessableEntity,
"No upstream providers are configured",
)
case len(oidcUpstreams)+len(ldapUpstreams)+len(adUpstreams) > 1:
var upstreamIDPNames []string
for _, idp := range oidcUpstreams {
upstreamIDPNames = append(upstreamIDPNames, idp.GetName())
}
for _, idp := range ldapUpstreams {
upstreamIDPNames = append(upstreamIDPNames, idp.GetName())
}
for _, idp := range adUpstreams {
upstreamIDPNames = append(upstreamIDPNames, idp.GetName())
}
plog.Warning("Too many upstream providers are configured (found: %s)", upstreamIDPNames)
return nil, nil, "", httperr.New(
http.StatusUnprocessableEntity,
"Too many upstream providers are configured (support for multiple upstreams is not yet implemented)",
)
case len(oidcUpstreams) == 1:
return oidcUpstreams[0], nil, psession.ProviderTypeOIDC, nil
case len(adUpstreams) == 1:
return nil, adUpstreams[0], psession.ProviderTypeActiveDirectory, nil
default:
return nil, ldapUpstreams[0], psession.ProviderTypeLDAP, nil
} }
return idpLister.FindUpstreamIDPByDisplayName(idpDisplayName)
}
func maybeLogDeprecationWarningForMissingIDPParam(idpNameQueryParamValue string, authorizeRequester fosite.AuthorizeRequester) {
if len(idpNameQueryParamValue) != 0 {
return
}
plog.Warning("Client attempted to perform an authorization flow (user login) without specifying the "+
"query param to choose an identity provider. "+
"This will not work when identity providers are configured explicitly on a FederationDomain. "+
"Additionally, this behavior is deprecated and support for any authorization requests missing this query param "+
"may be removed in a future release. "+
"Please ask the author of this client to update the authorization request URL to include this query parameter. "+
"The value of the parameter should be equal to the displayName of the identity provider as declared in the FederationDomain.",
"missingParameterName", oidcapi.AuthorizeUpstreamIDPNameParamName,
"clientID", authorizeRequester.GetClient().GetID(),
)
} }
type browserFlowAuthRequestState struct { type browserFlowAuthRequestState struct {
@ -438,16 +470,19 @@ func handleBrowserFlowAuthRequest(
generateCSRF func() (csrftoken.CSRFToken, error), generateCSRF func() (csrftoken.CSRFToken, error),
generateNonce func() (nonce.Nonce, error), generateNonce func() (nonce.Nonce, error),
generatePKCE func() (pkce.Code, error), generatePKCE func() (pkce.Code, error),
upstreamName string, upstreamDisplayName string,
idpType psession.ProviderType, idpType psession.ProviderType,
cookieCodec oidc.Codec, cookieCodec oidc.Codec,
upstreamStateEncoder oidc.Encoder, upstreamStateEncoder oidc.Encoder,
idpNameQueryParamValue string,
) (*browserFlowAuthRequestState, error) { ) (*browserFlowAuthRequestState, error) {
authorizeRequester, created := newAuthorizeRequest(r, w, oauthHelper, false) authorizeRequester, created := newAuthorizeRequest(r, w, oauthHelper, false)
if !created { if !created {
return nil, nil // already wrote the error response, don't return error return nil, nil // already wrote the error response, don't return error
} }
maybeLogDeprecationWarningForMissingIDPParam(idpNameQueryParamValue, authorizeRequester)
now := time.Now() now := time.Now()
_, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, &psession.PinnipedSession{ _, err := oauthHelper.NewAuthorizeResponse(r.Context(), authorizeRequester, &psession.PinnipedSession{
Fosite: &openid.DefaultSession{ Fosite: &openid.DefaultSession{
@ -476,7 +511,7 @@ func handleBrowserFlowAuthRequest(
encodedStateParamValue, err := upstreamStateParam( encodedStateParamValue, err := upstreamStateParam(
authorizeRequester, authorizeRequester,
upstreamName, upstreamDisplayName,
string(idpType), string(idpType),
nonceValue, nonceValue,
csrfValue, csrfValue,
@ -532,7 +567,7 @@ func generateValues(
func upstreamStateParam( func upstreamStateParam(
authorizeRequester fosite.AuthorizeRequester, authorizeRequester fosite.AuthorizeRequester,
upstreamName string, upstreamDisplayName string,
upstreamType string, upstreamType string,
nonceValue nonce.Nonce, nonceValue nonce.Nonce,
csrfValue csrftoken.CSRFToken, csrfValue csrftoken.CSRFToken,
@ -546,7 +581,7 @@ func upstreamStateParam(
// The UpstreamName and UpstreamType struct fields can be used instead. // The UpstreamName and UpstreamType struct fields can be used instead.
// Remove those params here to avoid potential confusion about which should be used later. // Remove those params here to avoid potential confusion about which should be used later.
AuthParams: removeCustomIDPParams(authorizeRequester.GetRequestForm()).Encode(), AuthParams: removeCustomIDPParams(authorizeRequester.GetRequestForm()).Encode(),
UpstreamName: upstreamName, UpstreamName: upstreamDisplayName,
UpstreamType: upstreamType, UpstreamType: upstreamType,
Nonce: nonceValue, Nonce: nonceValue,
CSRFToken: csrfValue, CSRFToken: csrfValue,

View File

@ -35,7 +35,7 @@ import (
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/oidcclientvalidator" "go.pinniped.dev/internal/oidc/oidcclientvalidator"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/psession"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/internal/testutil/oidctestutil" "go.pinniped.dev/internal/testutil/oidctestutil"
@ -3287,7 +3287,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
WithScopes([]string{"some-other-new-scope1", "some-other-new-scope2"}). WithScopes([]string{"some-other-new-scope1", "some-other-new-scope2"}).
WithAdditionalAuthcodeParams(map[string]string{"prompt": "consent", "abc": "123"}). WithAdditionalAuthcodeParams(map[string]string{"prompt": "consent", "abc": "123"}).
Build() Build()
idpLister.SetOIDCIdentityProviders([]provider.UpstreamOIDCIdentityProviderI{provider.UpstreamOIDCIdentityProviderI(newProviderSettings)}) idpLister.SetOIDCIdentityProviders([]upstreamprovider.UpstreamOIDCIdentityProviderI{upstreamprovider.UpstreamOIDCIdentityProviderI(newProviderSettings)})
// Update the expectations of the test case to match the new upstream IDP settings. // Update the expectations of the test case to match the new upstream IDP settings.
test.wantLocationHeader = urlWithQuery(upstreamAuthURL.String(), test.wantLocationHeader = urlWithQuery(upstreamAuthURL.String(),

View File

@ -20,7 +20,7 @@ import (
) )
func NewHandler( func NewHandler(
upstreamIDPs oidc.UpstreamOIDCIdentityProvidersLister, upstreamIDPs provider.FederationDomainIdentityProvidersFinderI,
oauthHelper fosite.OAuth2Provider, oauthHelper fosite.OAuth2Provider,
stateDecoder, cookieDecoder oidc.Decoder, stateDecoder, cookieDecoder oidc.Decoder,
redirectURI string, redirectURI string,
@ -31,11 +31,12 @@ func NewHandler(
return err return err
} }
upstreamIDPConfig := findUpstreamIDPConfig(state.UpstreamName, upstreamIDPs) resolvedOIDCIdentityProvider, _, err := upstreamIDPs.FindUpstreamIDPByDisplayName(state.UpstreamName)
if upstreamIDPConfig == nil { if err != nil || resolvedOIDCIdentityProvider == nil {
plog.Warning("upstream provider not found") plog.Warning("upstream provider not found")
return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found") return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found")
} }
upstreamIDPConfig := resolvedOIDCIdentityProvider.Provider
downstreamAuthParams, err := url.ParseQuery(state.AuthParams) downstreamAuthParams, err := url.ParseQuery(state.AuthParams)
if err != nil { if err != nil {
@ -69,14 +70,19 @@ func NewHandler(
return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens") return httperr.New(http.StatusBadGateway, "error exchanging and validating upstream tokens")
} }
subject, username, groups, err := downstreamsession.GetDownstreamIdentityFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims) subject, upstreamUsername, upstreamGroups, err := downstreamsession.GetDownstreamIdentityFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims)
if err != nil {
return httperr.Wrap(http.StatusUnprocessableEntity, err.Error(), err)
}
username, groups, err := downstreamsession.ApplyIdentityTransformations(r.Context(), resolvedOIDCIdentityProvider.Transforms, upstreamUsername, upstreamGroups)
if err != nil { if err != nil {
return httperr.Wrap(http.StatusUnprocessableEntity, err.Error(), err) return httperr.Wrap(http.StatusUnprocessableEntity, err.Error(), err)
} }
additionalClaims := downstreamsession.MapAdditionalClaimsFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims) additionalClaims := downstreamsession.MapAdditionalClaimsFromUpstreamIDToken(upstreamIDPConfig, token.IDToken.Claims)
customSessionData, err := downstreamsession.MakeDownstreamOIDCCustomSessionData(upstreamIDPConfig, token, username) customSessionData, err := downstreamsession.MakeDownstreamOIDCCustomSessionData(upstreamIDPConfig, token, username, upstreamUsername, upstreamGroups)
if err != nil { if err != nil {
return httperr.Wrap(http.StatusUnprocessableEntity, err.Error(), err) return httperr.Wrap(http.StatusUnprocessableEntity, err.Error(), err)
} }
@ -120,12 +126,3 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder)
return decodedState, nil return decodedState, nil
} }
func findUpstreamIDPConfig(upstreamName string, upstreamIDPs oidc.UpstreamOIDCIdentityProvidersLister) provider.UpstreamOIDCIdentityProviderI {
for _, p := range upstreamIDPs.GetOIDCIdentityProviders() {
if p.GetName() == upstreamName {
return p
}
}
return nil
}

View File

@ -5,6 +5,7 @@
package downstreamsession package downstreamsession
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"net/url" "net/url"
@ -19,8 +20,9 @@ import (
oidcapi "go.pinniped.dev/generated/latest/apis/supervisor/oidc" oidcapi "go.pinniped.dev/generated/latest/apis/supervisor/oidc"
"go.pinniped.dev/internal/authenticators" "go.pinniped.dev/internal/authenticators"
"go.pinniped.dev/internal/constable" "go.pinniped.dev/internal/constable"
"go.pinniped.dev/internal/idtransform"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/psession"
"go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/oidctypes"
@ -38,6 +40,8 @@ const (
requiredClaimEmptyErr = constable.Error("required claim in upstream ID token is empty") requiredClaimEmptyErr = constable.Error("required claim in upstream ID token is empty")
emailVerifiedClaimInvalidFormatErr = constable.Error("email_verified claim in upstream ID token has invalid format") emailVerifiedClaimInvalidFormatErr = constable.Error("email_verified claim in upstream ID token has invalid format")
emailVerifiedClaimFalseErr = constable.Error("email_verified claim in upstream ID token has false value") emailVerifiedClaimFalseErr = constable.Error("email_verified claim in upstream ID token has false value")
idTransformUnexpectedErr = constable.Error("configured identity transformation or policy resulted in unexpected error")
idTransformPolicyErr = constable.Error("configured identity policy rejected this authentication")
) )
// MakeDownstreamSession creates a downstream OIDC session. // MakeDownstreamSession creates a downstream OIDC session.
@ -82,16 +86,20 @@ func MakeDownstreamSession(
} }
func MakeDownstreamLDAPOrADCustomSessionData( func MakeDownstreamLDAPOrADCustomSessionData(
ldapUpstream provider.UpstreamLDAPIdentityProviderI, ldapUpstream upstreamprovider.UpstreamLDAPIdentityProviderI,
idpType psession.ProviderType, idpType psession.ProviderType,
authenticateResponse *authenticators.Response, authenticateResponse *authenticators.Response,
username string, username string,
untransformedUpstreamUsername string,
untransformedUpstreamGroups []string,
) *psession.CustomSessionData { ) *psession.CustomSessionData {
customSessionData := &psession.CustomSessionData{ customSessionData := &psession.CustomSessionData{
Username: username, Username: username,
ProviderUID: ldapUpstream.GetResourceUID(), UpstreamUsername: untransformedUpstreamUsername,
ProviderName: ldapUpstream.GetName(), UpstreamGroups: untransformedUpstreamGroups,
ProviderType: idpType, ProviderUID: ldapUpstream.GetResourceUID(),
ProviderName: ldapUpstream.GetName(),
ProviderType: idpType,
} }
if idpType == psession.ProviderTypeLDAP { if idpType == psession.ProviderTypeLDAP {
@ -112,9 +120,11 @@ func MakeDownstreamLDAPOrADCustomSessionData(
} }
func MakeDownstreamOIDCCustomSessionData( func MakeDownstreamOIDCCustomSessionData(
oidcUpstream provider.UpstreamOIDCIdentityProviderI, oidcUpstream upstreamprovider.UpstreamOIDCIdentityProviderI,
token *oidctypes.Token, token *oidctypes.Token,
username string, username string,
untransformedUpstreamUsername string,
untransformedUpstreamGroups []string,
) (*psession.CustomSessionData, error) { ) (*psession.CustomSessionData, error) {
upstreamSubject, err := ExtractStringClaimValue(oidcapi.IDTokenClaimSubject, oidcUpstream.GetName(), token.IDToken.Claims) upstreamSubject, err := ExtractStringClaimValue(oidcapi.IDTokenClaimSubject, oidcUpstream.GetName(), token.IDToken.Claims)
if err != nil { if err != nil {
@ -126,10 +136,12 @@ func MakeDownstreamOIDCCustomSessionData(
} }
customSessionData := &psession.CustomSessionData{ customSessionData := &psession.CustomSessionData{
Username: username, Username: username,
ProviderUID: oidcUpstream.GetResourceUID(), UpstreamUsername: untransformedUpstreamUsername,
ProviderName: oidcUpstream.GetName(), UpstreamGroups: untransformedUpstreamGroups,
ProviderType: psession.ProviderTypeOIDC, ProviderUID: oidcUpstream.GetResourceUID(),
ProviderName: oidcUpstream.GetName(),
ProviderType: psession.ProviderTypeOIDC,
OIDC: &psession.OIDCSessionData{ OIDC: &psession.OIDCSessionData{
UpstreamIssuer: upstreamIssuer, UpstreamIssuer: upstreamIssuer,
UpstreamSubject: upstreamSubject, UpstreamSubject: upstreamSubject,
@ -200,7 +212,7 @@ func AutoApproveScopes(authorizeRequester fosite.AuthorizeRequester) {
// GetDownstreamIdentityFromUpstreamIDToken returns the mapped subject, username, and group names, in that order. // GetDownstreamIdentityFromUpstreamIDToken returns the mapped subject, username, and group names, in that order.
func GetDownstreamIdentityFromUpstreamIDToken( func GetDownstreamIdentityFromUpstreamIDToken(
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, upstreamIDPConfig upstreamprovider.UpstreamOIDCIdentityProviderI,
idTokenClaims map[string]interface{}, idTokenClaims map[string]interface{},
) (string, string, []string, error) { ) (string, string, []string, error) {
subject, username, err := getSubjectAndUsernameFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims) subject, username, err := getSubjectAndUsernameFromUpstreamIDToken(upstreamIDPConfig, idTokenClaims)
@ -218,7 +230,7 @@ func GetDownstreamIdentityFromUpstreamIDToken(
// MapAdditionalClaimsFromUpstreamIDToken returns the additionalClaims mapped from the upstream token, if any. // MapAdditionalClaimsFromUpstreamIDToken returns the additionalClaims mapped from the upstream token, if any.
func MapAdditionalClaimsFromUpstreamIDToken( func MapAdditionalClaimsFromUpstreamIDToken(
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, upstreamIDPConfig upstreamprovider.UpstreamOIDCIdentityProviderI,
idTokenClaims map[string]interface{}, idTokenClaims map[string]interface{},
) map[string]interface{} { ) map[string]interface{} {
mapped := make(map[string]interface{}, len(upstreamIDPConfig.GetAdditionalClaimMappings())) mapped := make(map[string]interface{}, len(upstreamIDPConfig.GetAdditionalClaimMappings()))
@ -237,8 +249,32 @@ func MapAdditionalClaimsFromUpstreamIDToken(
return mapped return mapped
} }
func ApplyIdentityTransformations(
ctx context.Context,
identityTransforms *idtransform.TransformationPipeline,
username string,
groups []string,
) (string, []string, error) {
transformationResult, err := identityTransforms.Evaluate(ctx, username, groups)
if err != nil {
plog.Error("unexpected identity transformation error during authentication", err, "inputUsername", username)
return "", nil, idTransformUnexpectedErr
}
if !transformationResult.AuthenticationAllowed {
plog.Debug("authentication rejected by configured policy", "inputUsername", username, "inputGroups", groups)
return "", nil, idTransformPolicyErr
}
plog.Debug("identity transformation successfully applied during authentication",
"originalUsername", username,
"newUsername", transformationResult.Username,
"originalGroups", groups,
"newGroups", transformationResult.Groups,
)
return transformationResult.Username, transformationResult.Groups, nil
}
func getSubjectAndUsernameFromUpstreamIDToken( func getSubjectAndUsernameFromUpstreamIDToken(
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, upstreamIDPConfig upstreamprovider.UpstreamOIDCIdentityProviderI,
idTokenClaims map[string]interface{}, idTokenClaims map[string]interface{},
) (string, string, error) { ) (string, string, error) {
// The spec says the "sub" claim is only unique per issuer, // The spec says the "sub" claim is only unique per issuer,
@ -323,7 +359,7 @@ func ExtractStringClaimValue(claimName string, upstreamIDPName string, idTokenCl
return valueAsString, nil return valueAsString, nil
} }
func DownstreamSubjectFromUpstreamLDAP(ldapUpstream provider.UpstreamLDAPIdentityProviderI, authenticateResponse *authenticators.Response) string { func DownstreamSubjectFromUpstreamLDAP(ldapUpstream upstreamprovider.UpstreamLDAPIdentityProviderI, authenticateResponse *authenticators.Response) string {
ldapURL := *ldapUpstream.GetURL() ldapURL := *ldapUpstream.GetURL()
return DownstreamLDAPSubject(authenticateResponse.User.GetUID(), ldapURL) return DownstreamLDAPSubject(authenticateResponse.User.GetUID(), ldapURL)
} }
@ -343,7 +379,7 @@ func downstreamSubjectFromUpstreamOIDC(upstreamIssuerAsString string, upstreamSu
// It returns nil when there is no configured groups claim name, or then when the configured claim name is not found // It returns nil when there is no configured groups claim name, or then when the configured claim name is not found
// in the provided map of claims. It returns an error when the claim exists but its value cannot be parsed. // in the provided map of claims. It returns an error when the claim exists but its value cannot be parsed.
func GetGroupsFromUpstreamIDToken( func GetGroupsFromUpstreamIDToken(
upstreamIDPConfig provider.UpstreamOIDCIdentityProviderI, upstreamIDPConfig upstreamprovider.UpstreamOIDCIdentityProviderI,
idTokenClaims map[string]interface{}, idTokenClaims map[string]interface{},
) ([]string, error) { ) ([]string, error) {
groupsClaimName := upstreamIDPConfig.GetGroupsClaim() groupsClaimName := upstreamIDPConfig.GetGroupsClaim()

View File

@ -1,4 +1,4 @@
// Copyright 2021-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2021-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// Package idpdiscovery provides a handler for the upstream IDP discovery endpoint. // Package idpdiscovery provides a handler for the upstream IDP discovery endpoint.
@ -11,11 +11,11 @@ import (
"sort" "sort"
"go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1" "go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc/provider"
) )
// NewHandler returns an http.Handler that serves the upstream IDP discovery endpoint. // NewHandler returns an http.Handler that serves the upstream IDP discovery endpoint.
func NewHandler(upstreamIDPs oidc.UpstreamIdentityProvidersLister) http.Handler { func NewHandler(upstreamIDPs provider.FederationDomainIdentityProvidersListerI) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
http.Error(w, `Method not allowed (try GET)`, http.StatusMethodNotAllowed) http.Error(w, `Method not allowed (try GET)`, http.StatusMethodNotAllowed)
@ -36,31 +36,31 @@ func NewHandler(upstreamIDPs oidc.UpstreamIdentityProvidersLister) http.Handler
}) })
} }
func responseAsJSON(upstreamIDPs oidc.UpstreamIdentityProvidersLister) ([]byte, error) { func responseAsJSON(upstreamIDPs provider.FederationDomainIdentityProvidersListerI) ([]byte, error) {
r := v1alpha1.IDPDiscoveryResponse{PinnipedIDPs: []v1alpha1.PinnipedIDP{}} r := v1alpha1.IDPDiscoveryResponse{PinnipedIDPs: []v1alpha1.PinnipedIDP{}}
// The cache of IDPs could change at any time, so always recalculate the list. // The cache of IDPs could change at any time, so always recalculate the list.
for _, provider := range upstreamIDPs.GetLDAPIdentityProviders() { for _, federationDomainIdentityProvider := range upstreamIDPs.GetLDAPIdentityProviders() {
r.PinnipedIDPs = append(r.PinnipedIDPs, v1alpha1.PinnipedIDP{ r.PinnipedIDPs = append(r.PinnipedIDPs, v1alpha1.PinnipedIDP{
Name: provider.GetName(), Name: federationDomainIdentityProvider.DisplayName,
Type: v1alpha1.IDPTypeLDAP, Type: v1alpha1.IDPTypeLDAP,
Flows: []v1alpha1.IDPFlow{v1alpha1.IDPFlowCLIPassword, v1alpha1.IDPFlowBrowserAuthcode}, Flows: []v1alpha1.IDPFlow{v1alpha1.IDPFlowCLIPassword, v1alpha1.IDPFlowBrowserAuthcode},
}) })
} }
for _, provider := range upstreamIDPs.GetActiveDirectoryIdentityProviders() { for _, federationDomainIdentityProvider := range upstreamIDPs.GetActiveDirectoryIdentityProviders() {
r.PinnipedIDPs = append(r.PinnipedIDPs, v1alpha1.PinnipedIDP{ r.PinnipedIDPs = append(r.PinnipedIDPs, v1alpha1.PinnipedIDP{
Name: provider.GetName(), Name: federationDomainIdentityProvider.DisplayName,
Type: v1alpha1.IDPTypeActiveDirectory, Type: v1alpha1.IDPTypeActiveDirectory,
Flows: []v1alpha1.IDPFlow{v1alpha1.IDPFlowCLIPassword, v1alpha1.IDPFlowBrowserAuthcode}, Flows: []v1alpha1.IDPFlow{v1alpha1.IDPFlowCLIPassword, v1alpha1.IDPFlowBrowserAuthcode},
}) })
} }
for _, provider := range upstreamIDPs.GetOIDCIdentityProviders() { for _, federationDomainIdentityProvider := range upstreamIDPs.GetOIDCIdentityProviders() {
flows := []v1alpha1.IDPFlow{v1alpha1.IDPFlowBrowserAuthcode} flows := []v1alpha1.IDPFlow{v1alpha1.IDPFlowBrowserAuthcode}
if provider.AllowsPasswordGrant() { if federationDomainIdentityProvider.Provider.AllowsPasswordGrant() {
flows = append(flows, v1alpha1.IDPFlowCLIPassword) flows = append(flows, v1alpha1.IDPFlowCLIPassword)
} }
r.PinnipedIDPs = append(r.PinnipedIDPs, v1alpha1.PinnipedIDP{ r.PinnipedIDPs = append(r.PinnipedIDPs, v1alpha1.PinnipedIDP{
Name: provider.GetName(), Name: federationDomainIdentityProvider.DisplayName,
Type: v1alpha1.IDPTypeOIDC, Type: v1alpha1.IDPTypeOIDC,
Flows: flows, Flows: flows,
}) })

View File

@ -1,4 +1,4 @@
// Copyright 2021-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2021-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package idpdiscovery package idpdiscovery
@ -12,7 +12,7 @@ import (
"go.pinniped.dev/internal/here" "go.pinniped.dev/internal/here"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/testutil/oidctestutil" "go.pinniped.dev/internal/testutil/oidctestutil"
) )
@ -99,16 +99,16 @@ func TestIDPDiscovery(t *testing.T) {
} }
// Change the list of IDPs in the cache. // Change the list of IDPs in the cache.
idpLister.SetLDAPIdentityProviders([]provider.UpstreamLDAPIdentityProviderI{ idpLister.SetLDAPIdentityProviders([]upstreamprovider.UpstreamLDAPIdentityProviderI{
&oidctestutil.TestUpstreamLDAPIdentityProvider{Name: "some-other-ldap-idp-1"}, &oidctestutil.TestUpstreamLDAPIdentityProvider{Name: "some-other-ldap-idp-1"},
&oidctestutil.TestUpstreamLDAPIdentityProvider{Name: "some-other-ldap-idp-2"}, &oidctestutil.TestUpstreamLDAPIdentityProvider{Name: "some-other-ldap-idp-2"},
}) })
idpLister.SetOIDCIdentityProviders([]provider.UpstreamOIDCIdentityProviderI{ idpLister.SetOIDCIdentityProviders([]upstreamprovider.UpstreamOIDCIdentityProviderI{
&oidctestutil.TestUpstreamOIDCIdentityProvider{Name: "some-other-oidc-idp-1", AllowPasswordGrant: true}, &oidctestutil.TestUpstreamOIDCIdentityProvider{Name: "some-other-oidc-idp-1", AllowPasswordGrant: true},
&oidctestutil.TestUpstreamOIDCIdentityProvider{Name: "some-other-oidc-idp-2"}, &oidctestutil.TestUpstreamOIDCIdentityProvider{Name: "some-other-oidc-idp-2"},
}) })
idpLister.SetActiveDirectoryIdentityProviders([]provider.UpstreamLDAPIdentityProviderI{ idpLister.SetActiveDirectoryIdentityProviders([]upstreamprovider.UpstreamLDAPIdentityProviderI{
&oidctestutil.TestUpstreamLDAPIdentityProvider{Name: "some-other-ad-idp-2"}, &oidctestutil.TestUpstreamLDAPIdentityProvider{Name: "some-other-ad-idp-2"},
&oidctestutil.TestUpstreamLDAPIdentityProvider{Name: "some-other-ad-idp-1"}, &oidctestutil.TestUpstreamLDAPIdentityProvider{Name: "some-other-ad-idp-1"},
}) })

View File

@ -0,0 +1,26 @@
// Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package idplister
import (
"go.pinniped.dev/internal/oidc/provider/upstreamprovider"
)
type UpstreamOIDCIdentityProvidersLister interface {
GetOIDCIdentityProviders() []upstreamprovider.UpstreamOIDCIdentityProviderI
}
type UpstreamLDAPIdentityProvidersLister interface {
GetLDAPIdentityProviders() []upstreamprovider.UpstreamLDAPIdentityProviderI
}
type UpstreamActiveDirectoryIdentityProviderLister interface {
GetActiveDirectoryIdentityProviders() []upstreamprovider.UpstreamLDAPIdentityProviderI
}
type UpstreamIdentityProvidersLister interface {
UpstreamOIDCIdentityProvidersLister
UpstreamLDAPIdentityProvidersLister
UpstreamActiveDirectoryIdentityProviderLister
}

View File

@ -1,4 +1,4 @@
// Copyright 2022 the Pinniped contributors. All Rights Reserved. // Copyright 2022-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package login package login
@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/idplister"
"go.pinniped.dev/internal/oidc/login/loginhtml" "go.pinniped.dev/internal/oidc/login/loginhtml"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
) )
@ -28,7 +29,7 @@ func TestGetLogin(t *testing.T) {
decodedState *oidc.UpstreamStateParamData decodedState *oidc.UpstreamStateParamData
encodedState string encodedState string
errParam string errParam string
idps oidc.UpstreamIdentityProvidersLister idps idplister.UpstreamIdentityProvidersLister
wantStatus int wantStatus int
wantContentType string wantContentType string
wantBody string wantBody string

View File

@ -12,13 +12,14 @@ import (
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/downstreamsession" "go.pinniped.dev/internal/oidc/downstreamsession"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
) )
func NewPostHandler(issuerURL string, upstreamIDPs oidc.UpstreamIdentityProvidersLister, oauthHelper fosite.OAuth2Provider) HandlerFunc { func NewPostHandler(issuerURL string, upstreamIDPs provider.FederationDomainIdentityProvidersFinderI, oauthHelper fosite.OAuth2Provider) HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, encodedState string, decodedState *oidc.UpstreamStateParamData) error { return func(w http.ResponseWriter, r *http.Request, encodedState string, decodedState *oidc.UpstreamStateParamData) error {
// Note that the login handler prevents this handler from being called with OIDC upstreams. // Note that the login handler prevents this handler from being called with OIDC upstreams.
_, ldapUpstream, idpType, err := oidc.FindUpstreamIDPByNameAndType(upstreamIDPs, decodedState.UpstreamName, decodedState.UpstreamType) _, ldapUpstream, err := upstreamIDPs.FindUpstreamIDPByDisplayName(decodedState.UpstreamName)
if err != nil { if err != nil {
// This shouldn't normally happen because the authorization endpoint ensured that this provider existed // This shouldn't normally happen because the authorization endpoint ensured that this provider existed
// at that time. It would be possible in the unlikely event that the provider was deleted during the login. // at that time. It would be possible in the unlikely event that the provider was deleted during the login.
@ -51,20 +52,20 @@ func NewPostHandler(issuerURL string, upstreamIDPs oidc.UpstreamIdentityProvider
downstreamsession.AutoApproveScopes(authorizeRequester) downstreamsession.AutoApproveScopes(authorizeRequester)
// Get the username and password form params from the POST body. // Get the username and password form params from the POST body.
username := r.PostFormValue(usernameParamName) submittedUsername := r.PostFormValue(usernameParamName)
password := r.PostFormValue(passwordParamName) submittedPassword := r.PostFormValue(passwordParamName)
// Treat blank username or password as a bad username/password combination, as opposed to an internal error. // Treat blank username or password as a bad username/password combination, as opposed to an internal error.
if username == "" || password == "" { if submittedUsername == "" || submittedPassword == "" {
// User forgot to enter one of the required fields. // User forgot to enter one of the required fields.
// The user may try to log in again if they'd like, so redirect back to the login page with an error. // The user may try to log in again if they'd like, so redirect back to the login page with an error.
return RedirectToLoginPage(r, w, issuerURL, encodedState, ShowBadUserPassErr) return RedirectToLoginPage(r, w, issuerURL, encodedState, ShowBadUserPassErr)
} }
// Attempt to authenticate the user with the upstream IDP. // Attempt to authenticate the user with the upstream IDP.
authenticateResponse, authenticated, err := ldapUpstream.AuthenticateUser(r.Context(), username, password, authorizeRequester.GetGrantedScopes()) authenticateResponse, authenticated, err := ldapUpstream.Provider.AuthenticateUser(r.Context(), submittedUsername, submittedPassword, authorizeRequester.GetGrantedScopes())
if err != nil { if err != nil {
plog.WarningErr("unexpected error during upstream LDAP authentication", err, "upstreamName", ldapUpstream.GetName()) plog.WarningErr("unexpected error during upstream LDAP authentication", err, "upstreamName", ldapUpstream.Provider.GetName())
// There was some problem during authentication with the upstream, aside from bad username/password. // There was some problem during authentication with the upstream, aside from bad username/password.
// The user may try to log in again if they'd like, so redirect back to the login page with an error. // The user may try to log in again if they'd like, so redirect back to the login page with an error.
return RedirectToLoginPage(r, w, issuerURL, encodedState, ShowInternalError) return RedirectToLoginPage(r, w, issuerURL, encodedState, ShowInternalError)
@ -79,10 +80,19 @@ func NewPostHandler(issuerURL string, upstreamIDPs oidc.UpstreamIdentityProvider
// Now the upstream IDP has authenticated the user, so now we're back into the regular OIDC authcode flow steps. // Now the upstream IDP has authenticated the user, so now we're back into the regular OIDC authcode flow steps.
// Both success and error responses from this point onwards should look like the usual fosite redirect // Both success and error responses from this point onwards should look like the usual fosite redirect
// responses, and a happy redirect response will include a downstream authcode. // responses, and a happy redirect response will include a downstream authcode.
subject := downstreamsession.DownstreamSubjectFromUpstreamLDAP(ldapUpstream, authenticateResponse) subject := downstreamsession.DownstreamSubjectFromUpstreamLDAP(ldapUpstream.Provider, authenticateResponse)
username = authenticateResponse.User.GetName() upstreamUsername := authenticateResponse.User.GetName()
groups := authenticateResponse.User.GetGroups() upstreamGroups := authenticateResponse.User.GetGroups()
customSessionData := downstreamsession.MakeDownstreamLDAPOrADCustomSessionData(ldapUpstream, idpType, authenticateResponse, username)
username, groups, err := downstreamsession.ApplyIdentityTransformations(r.Context(), ldapUpstream.Transforms, upstreamUsername, upstreamGroups)
if err != nil {
oidc.WriteAuthorizeError(r, w, oauthHelper, authorizeRequester,
fosite.ErrAccessDenied.WithHintf("Reason: %s.", err.Error()), false,
)
return nil
}
customSessionData := downstreamsession.MakeDownstreamLDAPOrADCustomSessionData(ldapUpstream.Provider, ldapUpstream.SessionProviderType, authenticateResponse, username, upstreamUsername, upstreamGroups)
openIDSession := downstreamsession.MakeDownstreamSession(subject, username, groups, openIDSession := downstreamsession.MakeDownstreamSession(subject, username, groups,
authorizeRequester.GetGrantedScopes(), authorizeRequester.GetClient().GetID(), customSessionData, map[string]interface{}{}) authorizeRequester.GetGrantedScopes(), authorizeRequester.GetClient().GetID(), customSessionData, map[string]interface{}{})
oidc.PerformAuthcodeRedirect(r, w, oauthHelper, authorizeRequester, openIDSession, false) oidc.PerformAuthcodeRedirect(r, w, oauthHelper, authorizeRequester, openIDSession, false)

View File

@ -1,4 +1,4 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// Package oidc contains common OIDC functionality needed by Pinniped. // Package oidc contains common OIDC functionality needed by Pinniped.
@ -7,7 +7,6 @@ package oidc
import ( import (
"context" "context"
"crypto/subtle" "crypto/subtle"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@ -18,12 +17,10 @@ import (
"github.com/ory/fosite/compose" "github.com/ory/fosite/compose"
errorsx "github.com/pkg/errors" errorsx "github.com/pkg/errors"
"go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1"
oidcapi "go.pinniped.dev/generated/latest/apis/supervisor/oidc" oidcapi "go.pinniped.dev/generated/latest/apis/supervisor/oidc"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/formposthtml" "go.pinniped.dev/internal/oidc/provider/formposthtml"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/psession"
@ -279,24 +276,6 @@ func FositeErrorForLog(err error) []interface{} {
return keysAndValues return keysAndValues
} }
type UpstreamOIDCIdentityProvidersLister interface {
GetOIDCIdentityProviders() []provider.UpstreamOIDCIdentityProviderI
}
type UpstreamLDAPIdentityProvidersLister interface {
GetLDAPIdentityProviders() []provider.UpstreamLDAPIdentityProviderI
}
type UpstreamActiveDirectoryIdentityProviderLister interface {
GetActiveDirectoryIdentityProviders() []provider.UpstreamLDAPIdentityProviderI
}
type UpstreamIdentityProvidersLister interface {
UpstreamOIDCIdentityProvidersLister
UpstreamLDAPIdentityProvidersLister
UpstreamActiveDirectoryIdentityProviderLister
}
func GrantScopeIfRequested(authorizeRequester fosite.AuthorizeRequester, scopeName string) { func GrantScopeIfRequested(authorizeRequester fosite.AuthorizeRequester, scopeName string) {
if ScopeWasRequested(authorizeRequester, scopeName) { if ScopeWasRequested(authorizeRequester, scopeName) {
authorizeRequester.GrantScope(scopeName) authorizeRequester.GrantScope(scopeName)
@ -377,41 +356,6 @@ func validateCSRFValue(state *UpstreamStateParamData, csrfCookieValue csrftoken.
return nil return nil
} }
// FindUpstreamIDPByNameAndType finds the requested IDP by name and type, or returns an error.
// Note that AD and LDAP IDPs both return the same interface type, but different ProviderTypes values.
func FindUpstreamIDPByNameAndType(
idpLister UpstreamIdentityProvidersLister,
upstreamName string,
upstreamType string,
) (
provider.UpstreamOIDCIdentityProviderI,
provider.UpstreamLDAPIdentityProviderI,
psession.ProviderType,
error,
) {
switch upstreamType {
case string(v1alpha1.IDPTypeOIDC):
for _, p := range idpLister.GetOIDCIdentityProviders() {
if p.GetName() == upstreamName {
return p, nil, psession.ProviderTypeOIDC, nil
}
}
case string(v1alpha1.IDPTypeLDAP):
for _, p := range idpLister.GetLDAPIdentityProviders() {
if p.GetName() == upstreamName {
return nil, p, psession.ProviderTypeLDAP, nil
}
}
case string(v1alpha1.IDPTypeActiveDirectory):
for _, p := range idpLister.GetActiveDirectoryIdentityProviders() {
if p.GetName() == upstreamName {
return nil, p, psession.ProviderTypeActiveDirectory, nil
}
}
}
return nil, nil, "", errors.New("provider not found")
}
// WriteAuthorizeError writes an authorization error as it should be returned by the authorization endpoint and other // WriteAuthorizeError writes an authorization error as it should be returned by the authorization endpoint and other
// similar endpoints that are the end of the downstream authcode flow. Errors responses are written in the usual fosite style. // similar endpoints that are the end of the downstream authcode flow. Errors responses are written in the usual fosite style.
func WriteAuthorizeError(r *http.Request, w http.ResponseWriter, oauthHelper fosite.OAuth2Provider, authorizeRequester fosite.AuthorizeRequester, err error, isBrowserless bool) { func WriteAuthorizeError(r *http.Request, w http.ResponseWriter, oauthHelper fosite.OAuth2Provider, authorizeRequester fosite.AuthorizeRequester, err error, isBrowserless bool) {

View File

@ -4,182 +4,67 @@
package provider package provider
import ( import (
"context"
"fmt" "fmt"
"net/url"
"sync" "sync"
"golang.org/x/oauth2" "go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"k8s.io/apimachinery/pkg/types"
"go.pinniped.dev/internal/authenticators"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce"
) )
type RevocableTokenType string
// These strings correspond to the token types defined by https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
const (
RefreshTokenType RevocableTokenType = "refresh_token"
AccessTokenType RevocableTokenType = "access_token"
)
type UpstreamOIDCIdentityProviderI interface {
// GetName returns a name for this upstream provider, which will be used as a component of the path for the
// callback endpoint hosted by the Supervisor.
GetName() string
// GetClientID returns the OAuth client ID registered with the upstream provider to be used in the authorization code flow.
GetClientID() string
// GetResourceUID returns the Kubernetes resource ID
GetResourceUID() types.UID
// GetAuthorizationURL returns the Authorization Endpoint fetched from discovery.
GetAuthorizationURL() *url.URL
// HasUserInfoURL returns whether there is a non-empty value for userinfo_endpoint fetched from discovery.
HasUserInfoURL() bool
// GetScopes returns the scopes to request in authorization (authcode or password grant) flow.
GetScopes() []string
// GetUsernameClaim returns the ID Token username claim name. May return empty string, in which case we
// will use some reasonable defaults.
GetUsernameClaim() string
// GetGroupsClaim returns the ID Token groups claim name. May return empty string, in which case we won't
// try to read groups from the upstream provider.
GetGroupsClaim() string
// AllowsPasswordGrant returns true if a client should be allowed to use the resource owner password credentials grant
// flow with this upstream provider. When false, it should not be allowed.
AllowsPasswordGrant() bool
// GetAdditionalAuthcodeParams returns additional params to be sent on authcode requests.
GetAdditionalAuthcodeParams() map[string]string
// GetAdditionalClaimMappings returns additional claims to be mapped from the upstream ID token.
GetAdditionalClaimMappings() map[string]string
// PasswordCredentialsGrantAndValidateTokens performs upstream OIDC resource owner password credentials grant and
// token validation. Returns the validated raw tokens as well as the parsed claims of the ID token.
PasswordCredentialsGrantAndValidateTokens(ctx context.Context, username, password string) (*oidctypes.Token, error)
// ExchangeAuthcodeAndValidateTokens performs upstream OIDC authorization code exchange and token validation.
// Returns the validated raw tokens as well as the parsed claims of the ID token.
ExchangeAuthcodeAndValidateTokens(
ctx context.Context,
authcode string,
pkceCodeVerifier pkce.Code,
expectedIDTokenNonce nonce.Nonce,
redirectURI string,
) (*oidctypes.Token, error)
// PerformRefresh will call the provider's token endpoint to perform a refresh grant. The provider may or may not
// return a new ID or refresh token in the response. If it returns an ID token, then use ValidateToken to
// validate the ID token.
PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error)
// RevokeToken will attempt to revoke the given token, if the provider has a revocation endpoint.
// It may return an error wrapped by a RetryableRevocationError, which is an error indicating that it may
// be worth trying to revoke the same token again later. Any other error returned should be assumed to
// represent an error such that it is not worth retrying revocation later, even though revocation failed.
RevokeToken(ctx context.Context, token string, tokenType RevocableTokenType) error
// ValidateTokenAndMergeWithUserInfo will validate the ID token. It will also merge the claims from the userinfo endpoint response
// into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated
// tokens, or an error.
ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool, requireUserInfo bool) (*oidctypes.Token, error)
}
type UpstreamLDAPIdentityProviderI interface {
// GetName returns a name for this upstream provider.
GetName() string
// GetURL returns a URL which uniquely identifies this LDAP provider, e.g. "ldaps://host.example.com:1234".
// This URL is not used for connecting to the provider, but rather is used for creating a globally unique user
// identifier by being combined with the user's UID, since user UIDs are only unique within one provider.
GetURL() *url.URL
// GetResourceUID returns the Kubernetes resource ID
GetResourceUID() types.UID
// UserAuthenticator adds an interface method for performing user authentication against the upstream LDAP provider.
authenticators.UserAuthenticator
// PerformRefresh performs a refresh against the upstream LDAP identity provider
PerformRefresh(ctx context.Context, storedRefreshAttributes RefreshAttributes) (groups []string, err error)
}
// RefreshAttributes contains information about the user from the original login request
// and previous refreshes.
type RefreshAttributes struct {
Username string
Subject string
DN string
Groups []string
AdditionalAttributes map[string]string
GrantedScopes []string
}
type DynamicUpstreamIDPProvider interface { type DynamicUpstreamIDPProvider interface {
SetOIDCIdentityProviders(oidcIDPs []UpstreamOIDCIdentityProviderI) SetOIDCIdentityProviders(oidcIDPs []upstreamprovider.UpstreamOIDCIdentityProviderI)
GetOIDCIdentityProviders() []UpstreamOIDCIdentityProviderI GetOIDCIdentityProviders() []upstreamprovider.UpstreamOIDCIdentityProviderI
SetLDAPIdentityProviders(ldapIDPs []UpstreamLDAPIdentityProviderI) SetLDAPIdentityProviders(ldapIDPs []upstreamprovider.UpstreamLDAPIdentityProviderI)
GetLDAPIdentityProviders() []UpstreamLDAPIdentityProviderI GetLDAPIdentityProviders() []upstreamprovider.UpstreamLDAPIdentityProviderI
SetActiveDirectoryIdentityProviders(adIDPs []UpstreamLDAPIdentityProviderI) SetActiveDirectoryIdentityProviders(adIDPs []upstreamprovider.UpstreamLDAPIdentityProviderI)
GetActiveDirectoryIdentityProviders() []UpstreamLDAPIdentityProviderI GetActiveDirectoryIdentityProviders() []upstreamprovider.UpstreamLDAPIdentityProviderI
} }
type dynamicUpstreamIDPProvider struct { type dynamicUpstreamIDPProvider struct {
oidcUpstreams []UpstreamOIDCIdentityProviderI oidcUpstreams []upstreamprovider.UpstreamOIDCIdentityProviderI
ldapUpstreams []UpstreamLDAPIdentityProviderI ldapUpstreams []upstreamprovider.UpstreamLDAPIdentityProviderI
activeDirectoryUpstreams []UpstreamLDAPIdentityProviderI activeDirectoryUpstreams []upstreamprovider.UpstreamLDAPIdentityProviderI
mutex sync.RWMutex mutex sync.RWMutex
} }
func NewDynamicUpstreamIDPProvider() DynamicUpstreamIDPProvider { func NewDynamicUpstreamIDPProvider() DynamicUpstreamIDPProvider {
return &dynamicUpstreamIDPProvider{ return &dynamicUpstreamIDPProvider{
oidcUpstreams: []UpstreamOIDCIdentityProviderI{}, oidcUpstreams: []upstreamprovider.UpstreamOIDCIdentityProviderI{},
ldapUpstreams: []UpstreamLDAPIdentityProviderI{}, ldapUpstreams: []upstreamprovider.UpstreamLDAPIdentityProviderI{},
activeDirectoryUpstreams: []UpstreamLDAPIdentityProviderI{}, activeDirectoryUpstreams: []upstreamprovider.UpstreamLDAPIdentityProviderI{},
} }
} }
func (p *dynamicUpstreamIDPProvider) SetOIDCIdentityProviders(oidcIDPs []UpstreamOIDCIdentityProviderI) { func (p *dynamicUpstreamIDPProvider) SetOIDCIdentityProviders(oidcIDPs []upstreamprovider.UpstreamOIDCIdentityProviderI) {
p.mutex.Lock() // acquire a write lock p.mutex.Lock() // acquire a write lock
defer p.mutex.Unlock() defer p.mutex.Unlock()
p.oidcUpstreams = oidcIDPs p.oidcUpstreams = oidcIDPs
} }
func (p *dynamicUpstreamIDPProvider) GetOIDCIdentityProviders() []UpstreamOIDCIdentityProviderI { func (p *dynamicUpstreamIDPProvider) GetOIDCIdentityProviders() []upstreamprovider.UpstreamOIDCIdentityProviderI {
p.mutex.RLock() // acquire a read lock p.mutex.RLock() // acquire a read lock
defer p.mutex.RUnlock() defer p.mutex.RUnlock()
return p.oidcUpstreams return p.oidcUpstreams
} }
func (p *dynamicUpstreamIDPProvider) SetLDAPIdentityProviders(ldapIDPs []UpstreamLDAPIdentityProviderI) { func (p *dynamicUpstreamIDPProvider) SetLDAPIdentityProviders(ldapIDPs []upstreamprovider.UpstreamLDAPIdentityProviderI) {
p.mutex.Lock() // acquire a write lock p.mutex.Lock() // acquire a write lock
defer p.mutex.Unlock() defer p.mutex.Unlock()
p.ldapUpstreams = ldapIDPs p.ldapUpstreams = ldapIDPs
} }
func (p *dynamicUpstreamIDPProvider) GetLDAPIdentityProviders() []UpstreamLDAPIdentityProviderI { func (p *dynamicUpstreamIDPProvider) GetLDAPIdentityProviders() []upstreamprovider.UpstreamLDAPIdentityProviderI {
p.mutex.RLock() // acquire a read lock p.mutex.RLock() // acquire a read lock
defer p.mutex.RUnlock() defer p.mutex.RUnlock()
return p.ldapUpstreams return p.ldapUpstreams
} }
func (p *dynamicUpstreamIDPProvider) SetActiveDirectoryIdentityProviders(adIDPs []UpstreamLDAPIdentityProviderI) { func (p *dynamicUpstreamIDPProvider) SetActiveDirectoryIdentityProviders(adIDPs []upstreamprovider.UpstreamLDAPIdentityProviderI) {
p.mutex.Lock() // acquire a write lock p.mutex.Lock() // acquire a write lock
defer p.mutex.Unlock() defer p.mutex.Unlock()
p.activeDirectoryUpstreams = adIDPs p.activeDirectoryUpstreams = adIDPs
} }
func (p *dynamicUpstreamIDPProvider) GetActiveDirectoryIdentityProviders() []UpstreamLDAPIdentityProviderI { func (p *dynamicUpstreamIDPProvider) GetActiveDirectoryIdentityProviders() []upstreamprovider.UpstreamLDAPIdentityProviderI {
p.mutex.RLock() // acquire a read lock p.mutex.RLock() // acquire a read lock
defer p.mutex.RUnlock() defer p.mutex.RUnlock()
return p.activeDirectoryUpstreams return p.activeDirectoryUpstreams

View File

@ -0,0 +1,232 @@
// Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package provider
import (
"fmt"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/sets"
"go.pinniped.dev/internal/idtransform"
"go.pinniped.dev/internal/oidc/idplister"
"go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/psession"
)
// FederationDomainIdentityProvider represents an identity provider as configured in a FederationDomain's spec.
// All the fields are required and must be non-zero values. Note that this might be a reference to an IDP
// which is not currently loaded into the cache of available IDPs, e.g. due to the IDP's CR having validation errors.
type FederationDomainIdentityProvider struct {
DisplayName string
UID types.UID
Transforms *idtransform.TransformationPipeline
}
// FederationDomainResolvedOIDCIdentityProvider represents a FederationDomainIdentityProvider which has
// been resolved dynamically based on the currently loaded IDP CRs to include the provider.UpstreamOIDCIdentityProviderI
// and other metadata about the provider.
type FederationDomainResolvedOIDCIdentityProvider struct {
DisplayName string
Provider upstreamprovider.UpstreamOIDCIdentityProviderI
SessionProviderType psession.ProviderType
Transforms *idtransform.TransformationPipeline
}
// FederationDomainResolvedLDAPIdentityProvider represents a FederationDomainIdentityProvider which has
// been resolved dynamically based on the currently loaded IDP CRs to include the provider.UpstreamLDAPIdentityProviderI
// and other metadata about the provider.
type FederationDomainResolvedLDAPIdentityProvider struct {
DisplayName string
Provider upstreamprovider.UpstreamLDAPIdentityProviderI
SessionProviderType psession.ProviderType
Transforms *idtransform.TransformationPipeline
}
type FederationDomainIdentityProvidersFinderI interface {
FindDefaultIDP() (
*FederationDomainResolvedOIDCIdentityProvider,
*FederationDomainResolvedLDAPIdentityProvider,
error,
)
FindUpstreamIDPByDisplayName(upstreamIDPDisplayName string) (
*FederationDomainResolvedOIDCIdentityProvider,
*FederationDomainResolvedLDAPIdentityProvider,
error,
)
}
type FederationDomainIdentityProvidersListerI interface {
GetOIDCIdentityProviders() []*FederationDomainResolvedOIDCIdentityProvider
GetLDAPIdentityProviders() []*FederationDomainResolvedLDAPIdentityProvider
GetActiveDirectoryIdentityProviders() []*FederationDomainResolvedLDAPIdentityProvider
}
// FederationDomainIdentityProvidersLister wraps an UpstreamIdentityProvidersLister. The lister which is being
// wrapped should contain all valid upstream providers that are currently defined in the Supervisor.
// FederationDomainIdentityProvidersLister provides a lookup method which only looks up IDPs within those which
// have allowed resource IDs, and also uses display names (name aliases) instead of the actual resource names to do the
// lookups. It also provides list methods which only list the allowed identity providers (to be used by the IDP
// discovery endpoint, for example).
type FederationDomainIdentityProvidersLister struct {
wrappedLister idplister.UpstreamIdentityProvidersLister
configuredIdentityProviders []*FederationDomainIdentityProvider
defaultIdentityProvider *FederationDomainIdentityProvider
idpDisplayNamesToResourceUIDsMap map[string]types.UID
allowedIDPResourceUIDs sets.Set[types.UID]
}
// NewFederationDomainUpstreamIdentityProvidersLister returns a new FederationDomainIdentityProvidersLister
// which only lists those IDPs allowed by its parameter. Every FederationDomainIdentityProvider in the
// federationDomainIssuer parameter's IdentityProviders() list must have a unique DisplayName.
// Note that a single underlying IDP UID may be used by multiple FederationDomainIdentityProvider in the parameter.
// The wrapped lister should contain all valid upstream providers that are defined in the Supervisor, and is expected to
// be thread-safe and to change its contents over time. The FederationDomainIdentityProvidersLister will filter out the
// ones that don't apply to this federation domain.
func NewFederationDomainUpstreamIdentityProvidersLister(
federationDomainIssuer *FederationDomainIssuer,
wrappedLister idplister.UpstreamIdentityProvidersLister,
) *FederationDomainIdentityProvidersLister {
// Create a copy of the input slice so we won't need to worry about the caller accidentally changing it.
copyOfFederationDomainIdentityProviders := []*FederationDomainIdentityProvider{}
// Create a map and a set for quick lookups of the same data that was passed in via the
// federationDomainIssuer parameter.
allowedResourceUIDs := sets.New[types.UID]()
idpDisplayNamesToResourceUIDsMap := map[string]types.UID{}
for _, idp := range federationDomainIssuer.IdentityProviders() {
allowedResourceUIDs.Insert(idp.UID)
idpDisplayNamesToResourceUIDsMap[idp.DisplayName] = idp.UID
shallowCopyOfIDP := *idp
copyOfFederationDomainIdentityProviders = append(copyOfFederationDomainIdentityProviders, &shallowCopyOfIDP)
}
return &FederationDomainIdentityProvidersLister{
wrappedLister: wrappedLister,
configuredIdentityProviders: copyOfFederationDomainIdentityProviders,
defaultIdentityProvider: federationDomainIssuer.DefaultIdentityProvider(),
idpDisplayNamesToResourceUIDsMap: idpDisplayNamesToResourceUIDsMap,
allowedIDPResourceUIDs: allowedResourceUIDs,
}
}
// FindUpstreamIDPByDisplayName selects either an OIDC, LDAP, or ActiveDirectory IDP, or returns an error.
// It only considers the allowed IDPs while doing the lookup by display name.
// Note that ActiveDirectory and LDAP IDPs both return the same type, but with different SessionProviderType values.
func (u *FederationDomainIdentityProvidersLister) FindUpstreamIDPByDisplayName(upstreamIDPDisplayName string) (
*FederationDomainResolvedOIDCIdentityProvider,
*FederationDomainResolvedLDAPIdentityProvider,
error,
) {
// Given a display name, look up the identity provider's UID for that display name.
idpUIDForDisplayName, ok := u.idpDisplayNamesToResourceUIDsMap[upstreamIDPDisplayName]
if !ok {
return nil, nil, fmt.Errorf("identity provider not found: %q", upstreamIDPDisplayName)
}
// Find the IDP with that UID. It could be any type, so look at all types to find it.
for _, p := range u.GetOIDCIdentityProviders() {
if p.Provider.GetResourceUID() == idpUIDForDisplayName {
return p, nil, nil
}
}
for _, p := range u.GetLDAPIdentityProviders() {
if p.Provider.GetResourceUID() == idpUIDForDisplayName {
return nil, p, nil
}
}
for _, p := range u.GetActiveDirectoryIdentityProviders() {
if p.Provider.GetResourceUID() == idpUIDForDisplayName {
return nil, p, nil
}
}
return nil, nil, fmt.Errorf("identity provider not found: %q", upstreamIDPDisplayName)
}
// FindDefaultIDP works like FindUpstreamIDPByDisplayName, but finds the default IDP instead of finding by name.
// If there is no default IDP for this federation domain, then FindDefaultIDP will return an error.
// This can be used to handle the backwards compatibility mode where an authorization request could be made
// without specifying an IDP name, and there are no IDPs explicitly specified on the FederationDomain, and there
// is exactly one IDP CR defined in the Supervisor namespace.
func (u *FederationDomainIdentityProvidersLister) FindDefaultIDP() (
*FederationDomainResolvedOIDCIdentityProvider,
*FederationDomainResolvedLDAPIdentityProvider,
error,
) {
if u.defaultIdentityProvider == nil {
return nil, nil, fmt.Errorf("identity provider not found: this federation domain does not have a default identity provider")
}
return u.FindUpstreamIDPByDisplayName(u.defaultIdentityProvider.DisplayName)
}
// GetOIDCIdentityProviders lists only the OIDC providers for this FederationDomain.
func (u *FederationDomainIdentityProvidersLister) GetOIDCIdentityProviders() []*FederationDomainResolvedOIDCIdentityProvider {
// Get the cached providers once at the start in case they change during the rest of this function.
cachedProviders := u.wrappedLister.GetOIDCIdentityProviders()
providers := []*FederationDomainResolvedOIDCIdentityProvider{}
// Every configured identityProvider on the FederationDomain uses an objetRef to an underlying IDP CR that might
// be available as a provider in the wrapped cache. For each configured identityProvider/displayName...
for _, idp := range u.configuredIdentityProviders {
// Check if the IDP used by that displayName is in the cached available OIDC providers.
for _, p := range cachedProviders {
if idp.UID == p.GetResourceUID() {
// Found it, so append it to the result.
providers = append(providers, &FederationDomainResolvedOIDCIdentityProvider{
DisplayName: idp.DisplayName,
Provider: p,
SessionProviderType: psession.ProviderTypeOIDC,
Transforms: idp.Transforms,
})
}
}
}
return providers
}
// GetLDAPIdentityProviders lists only the LDAP providers for this FederationDomain.
func (u *FederationDomainIdentityProvidersLister) GetLDAPIdentityProviders() []*FederationDomainResolvedLDAPIdentityProvider {
// Get the cached providers once at the start in case they change during the rest of this function.
cachedProviders := u.wrappedLister.GetLDAPIdentityProviders()
providers := []*FederationDomainResolvedLDAPIdentityProvider{}
// Every configured identityProvider on the FederationDomain uses an objetRef to an underlying IDP CR that might
// be available as a provider in the wrapped cache. For each configured identityProvider/displayName...
for _, idp := range u.configuredIdentityProviders {
// Check if the IDP used by that displayName is in the cached available LDAP providers.
for _, p := range cachedProviders {
if idp.UID == p.GetResourceUID() {
// Found it, so append it to the result.
providers = append(providers, &FederationDomainResolvedLDAPIdentityProvider{
DisplayName: idp.DisplayName,
Provider: p,
SessionProviderType: psession.ProviderTypeLDAP,
Transforms: idp.Transforms,
})
}
}
}
return providers
}
// GetActiveDirectoryIdentityProviders lists only the ActiveDirectory providers for this FederationDomain.
func (u *FederationDomainIdentityProvidersLister) GetActiveDirectoryIdentityProviders() []*FederationDomainResolvedLDAPIdentityProvider {
// Get the cached providers once at the start in case they change during the rest of this function.
cachedProviders := u.wrappedLister.GetActiveDirectoryIdentityProviders()
providers := []*FederationDomainResolvedLDAPIdentityProvider{}
// Every configured identityProvider on the FederationDomain uses an objetRef to an underlying IDP CR that might
// be available as a provider in the wrapped cache. For each configured identityProvider/displayName...
for _, idp := range u.configuredIdentityProviders {
// Check if the IDP used by that displayName is in the cached available ActiveDirectory providers.
for _, p := range cachedProviders {
if idp.UID == p.GetResourceUID() {
// Found it, so append it to the result.
providers = append(providers, &FederationDomainResolvedLDAPIdentityProvider{
DisplayName: idp.DisplayName,
Provider: p,
SessionProviderType: psession.ProviderTypeActiveDirectory,
Transforms: idp.Transforms,
})
}
}
}
return providers
}

View File

@ -17,20 +17,42 @@ type FederationDomainIssuer struct {
issuer string issuer string
issuerHost string issuerHost string
issuerPath string issuerPath string
// identityProviders should be used when they are explicitly specified in the FederationDomain's spec.
identityProviders []*FederationDomainIdentityProvider
// defaultIdentityProvider should be used only for the backwards compatibility mode where identity providers
// are not explicitly specified in the FederationDomain's spec, and there is exactly one IDP CR defined in the
// Supervisor's namespace.
defaultIdentityProvider *FederationDomainIdentityProvider
} }
// NewFederationDomainIssuer returns a FederationDomainIssuer. // NewFederationDomainIssuer returns a FederationDomainIssuer.
// Performs validation, and returns any error from validation. // Performs validation, and returns any error from validation.
func NewFederationDomainIssuer(issuer string) (*FederationDomainIssuer, error) { func NewFederationDomainIssuer(
p := FederationDomainIssuer{issuer: issuer} issuer string,
err := p.validate() identityProviders []*FederationDomainIdentityProvider,
) (*FederationDomainIssuer, error) {
p := FederationDomainIssuer{issuer: issuer, identityProviders: identityProviders}
err := p.validateURL()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &p, nil return &p, nil
} }
func (p *FederationDomainIssuer) validate() error { func NewFederationDomainIssuerWithDefaultIDP(
issuer string,
defaultIdentityProvider *FederationDomainIdentityProvider,
) (*FederationDomainIssuer, error) {
fdi, err := NewFederationDomainIssuer(issuer, []*FederationDomainIdentityProvider{defaultIdentityProvider})
if err != nil {
return nil, err
}
fdi.defaultIdentityProvider = defaultIdentityProvider
return fdi, nil
}
func (p *FederationDomainIssuer) validateURL() error {
if p.issuer == "" { if p.issuer == "" {
return constable.Error("federation domain must have an issuer") return constable.Error("federation domain must have an issuer")
} }
@ -84,3 +106,13 @@ func (p *FederationDomainIssuer) IssuerHost() string {
func (p *FederationDomainIssuer) IssuerPath() string { func (p *FederationDomainIssuer) IssuerPath() string {
return p.issuerPath return p.issuerPath
} }
// IdentityProviders returns the IdentityProviders.
func (p *FederationDomainIssuer) IdentityProviders() []*FederationDomainIdentityProvider {
return p.identityProviders
}
// DefaultIdentityProvider will return nil when there is no default.
func (p *FederationDomainIssuer) DefaultIdentityProvider() *FederationDomainIdentityProvider {
return p.defaultIdentityProvider
}

View File

@ -19,6 +19,7 @@ import (
"go.pinniped.dev/internal/oidc/discovery" "go.pinniped.dev/internal/oidc/discovery"
"go.pinniped.dev/internal/oidc/dynamiccodec" "go.pinniped.dev/internal/oidc/dynamiccodec"
"go.pinniped.dev/internal/oidc/idpdiscovery" "go.pinniped.dev/internal/oidc/idpdiscovery"
"go.pinniped.dev/internal/oidc/idplister"
"go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/login" "go.pinniped.dev/internal/oidc/login"
"go.pinniped.dev/internal/oidc/oidcclientvalidator" "go.pinniped.dev/internal/oidc/oidcclientvalidator"
@ -36,11 +37,11 @@ import (
type Manager struct { type Manager struct {
mu sync.RWMutex mu sync.RWMutex
providers []*provider.FederationDomainIssuer providers []*provider.FederationDomainIssuer
providerHandlers map[string]http.Handler // map of all routes for all providers providerHandlers map[string]http.Handler // map of all routes for all providers
nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request nextHandler http.Handler // the next handler in a chain, called when this manager didn't know how to handle a request
dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data dynamicJWKSProvider jwks.DynamicJWKSProvider // in-memory cache of per-issuer JWKS data
upstreamIDPs oidc.UpstreamIdentityProvidersLister // in-memory cache of upstream IDPs upstreamIDPs idplister.UpstreamIdentityProvidersLister // in-memory cache of upstream IDPs
secretCache *secret.Cache // in-memory cache of cryptographic material secretCache *secret.Cache // in-memory cache of cryptographic material
secretsClient corev1client.SecretInterface secretsClient corev1client.SecretInterface
oidcClientsClient v1alpha1.OIDCClientInterface oidcClientsClient v1alpha1.OIDCClientInterface
} }
@ -52,7 +53,7 @@ type Manager struct {
func NewManager( func NewManager(
nextHandler http.Handler, nextHandler http.Handler,
dynamicJWKSProvider jwks.DynamicJWKSProvider, dynamicJWKSProvider jwks.DynamicJWKSProvider,
upstreamIDPs oidc.UpstreamIdentityProvidersLister, upstreamIDPs idplister.UpstreamIdentityProvidersLister,
secretCache *secret.Cache, secretCache *secret.Cache,
secretsClient corev1client.SecretInterface, secretsClient corev1client.SecretInterface,
oidcClientsClient v1alpha1.OIDCClientInterface, oidcClientsClient v1alpha1.OIDCClientInterface,
@ -83,17 +84,17 @@ func (m *Manager) SetProviders(federationDomains ...*provider.FederationDomainIs
m.providers = federationDomains m.providers = federationDomains
m.providerHandlers = make(map[string]http.Handler) m.providerHandlers = make(map[string]http.Handler)
var csrfCookieEncoder = dynamiccodec.New( csrfCookieEncoder := dynamiccodec.New(
oidc.CSRFCookieLifespan, oidc.CSRFCookieLifespan,
m.secretCache.GetCSRFCookieEncoderHashKey, m.secretCache.GetCSRFCookieEncoderHashKey,
func() []byte { return nil }, func() []byte { return nil },
) )
for _, incomingProvider := range federationDomains { for _, incomingFederationDomain := range federationDomains {
issuer := incomingProvider.Issuer() issuerURL := incomingFederationDomain.Issuer()
issuerHostWithPath := strings.ToLower(incomingProvider.IssuerHost()) + "/" + incomingProvider.IssuerPath() issuerHostWithPath := strings.ToLower(incomingFederationDomain.IssuerHost()) + "/" + incomingFederationDomain.IssuerPath()
tokenHMACKeyGetter := wrapGetter(incomingProvider.Issuer(), m.secretCache.GetTokenHMACKey) tokenHMACKeyGetter := wrapGetter(incomingFederationDomain.Issuer(), m.secretCache.GetTokenHMACKey)
timeoutsConfiguration := oidc.DefaultOIDCTimeoutsConfiguration() timeoutsConfiguration := oidc.DefaultOIDCTimeoutsConfiguration()
@ -101,7 +102,7 @@ func (m *Manager) SetProviders(federationDomains ...*provider.FederationDomainIs
// the upstream callback endpoint is called later. // the upstream callback endpoint is called later.
oauthHelperWithNullStorage := oidc.FositeOauth2Helper( oauthHelperWithNullStorage := oidc.FositeOauth2Helper(
oidc.NewNullStorage(m.secretsClient, m.oidcClientsClient, oidcclientvalidator.DefaultMinBcryptCost), oidc.NewNullStorage(m.secretsClient, m.oidcClientsClient, oidcclientvalidator.DefaultMinBcryptCost),
issuer, issuerURL,
tokenHMACKeyGetter, tokenHMACKeyGetter,
nil, nil,
timeoutsConfiguration, timeoutsConfiguration,
@ -110,27 +111,29 @@ func (m *Manager) SetProviders(federationDomains ...*provider.FederationDomainIs
// For all the other endpoints, make another oauth helper with exactly the same settings except use real storage. // For all the other endpoints, make another oauth helper with exactly the same settings except use real storage.
oauthHelperWithKubeStorage := oidc.FositeOauth2Helper( oauthHelperWithKubeStorage := oidc.FositeOauth2Helper(
oidc.NewKubeStorage(m.secretsClient, m.oidcClientsClient, timeoutsConfiguration, oidcclientvalidator.DefaultMinBcryptCost), oidc.NewKubeStorage(m.secretsClient, m.oidcClientsClient, timeoutsConfiguration, oidcclientvalidator.DefaultMinBcryptCost),
issuer, issuerURL,
tokenHMACKeyGetter, tokenHMACKeyGetter,
m.dynamicJWKSProvider, m.dynamicJWKSProvider,
timeoutsConfiguration, timeoutsConfiguration,
) )
var upstreamStateEncoder = dynamiccodec.New( upstreamStateEncoder := dynamiccodec.New(
timeoutsConfiguration.UpstreamStateParamLifespan, timeoutsConfiguration.UpstreamStateParamLifespan,
wrapGetter(incomingProvider.Issuer(), m.secretCache.GetStateEncoderHashKey), wrapGetter(incomingFederationDomain.Issuer(), m.secretCache.GetStateEncoderHashKey),
wrapGetter(incomingProvider.Issuer(), m.secretCache.GetStateEncoderBlockKey), wrapGetter(incomingFederationDomain.Issuer(), m.secretCache.GetStateEncoderBlockKey),
) )
m.providerHandlers[(issuerHostWithPath + oidc.WellKnownEndpointPath)] = discovery.NewHandler(issuer) idpLister := provider.NewFederationDomainUpstreamIdentityProvidersLister(incomingFederationDomain, m.upstreamIDPs)
m.providerHandlers[(issuerHostWithPath + oidc.JWKSEndpointPath)] = jwks.NewHandler(issuer, m.dynamicJWKSProvider) m.providerHandlers[(issuerHostWithPath + oidc.WellKnownEndpointPath)] = discovery.NewHandler(issuerURL)
m.providerHandlers[(issuerHostWithPath + oidc.PinnipedIDPsPathV1Alpha1)] = idpdiscovery.NewHandler(m.upstreamIDPs) m.providerHandlers[(issuerHostWithPath + oidc.JWKSEndpointPath)] = jwks.NewHandler(issuerURL, m.dynamicJWKSProvider)
m.providerHandlers[(issuerHostWithPath + oidc.PinnipedIDPsPathV1Alpha1)] = idpdiscovery.NewHandler(idpLister)
m.providerHandlers[(issuerHostWithPath + oidc.AuthorizationEndpointPath)] = auth.NewHandler( m.providerHandlers[(issuerHostWithPath + oidc.AuthorizationEndpointPath)] = auth.NewHandler(
issuer, issuerURL,
m.upstreamIDPs, idpLister,
oauthHelperWithNullStorage, oauthHelperWithNullStorage,
oauthHelperWithKubeStorage, oauthHelperWithKubeStorage,
csrftoken.Generate, csrftoken.Generate,
@ -141,26 +144,26 @@ func (m *Manager) SetProviders(federationDomains ...*provider.FederationDomainIs
) )
m.providerHandlers[(issuerHostWithPath + oidc.CallbackEndpointPath)] = callback.NewHandler( m.providerHandlers[(issuerHostWithPath + oidc.CallbackEndpointPath)] = callback.NewHandler(
m.upstreamIDPs, idpLister,
oauthHelperWithKubeStorage, oauthHelperWithKubeStorage,
upstreamStateEncoder, upstreamStateEncoder,
csrfCookieEncoder, csrfCookieEncoder,
issuer+oidc.CallbackEndpointPath, issuerURL+oidc.CallbackEndpointPath,
) )
m.providerHandlers[(issuerHostWithPath + oidc.TokenEndpointPath)] = token.NewHandler( m.providerHandlers[(issuerHostWithPath + oidc.TokenEndpointPath)] = token.NewHandler(
m.upstreamIDPs, idpLister,
oauthHelperWithKubeStorage, oauthHelperWithKubeStorage,
) )
m.providerHandlers[(issuerHostWithPath + oidc.PinnipedLoginPath)] = login.NewHandler( m.providerHandlers[(issuerHostWithPath + oidc.PinnipedLoginPath)] = login.NewHandler(
upstreamStateEncoder, upstreamStateEncoder,
csrfCookieEncoder, csrfCookieEncoder,
login.NewGetHandler(incomingProvider.IssuerPath()+oidc.PinnipedLoginPath), login.NewGetHandler(incomingFederationDomain.IssuerPath()+oidc.PinnipedLoginPath),
login.NewPostHandler(issuer, m.upstreamIDPs, oauthHelperWithKubeStorage), login.NewPostHandler(issuerURL, idpLister, oauthHelperWithKubeStorage),
) )
plog.Debug("oidc provider manager added or updated issuer", "issuer", issuer) plog.Debug("oidc provider manager added or updated issuer", "issuer", issuerURL)
} }
} }

View File

@ -0,0 +1,126 @@
// Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package upstreamprovider
import (
"context"
"net/url"
"golang.org/x/oauth2"
"k8s.io/apimachinery/pkg/types"
"go.pinniped.dev/internal/authenticators"
"go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce"
)
type RevocableTokenType string
// These strings correspond to the token types defined by https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
const (
RefreshTokenType RevocableTokenType = "refresh_token"
AccessTokenType RevocableTokenType = "access_token"
)
// RefreshAttributes contains information about the user from the original login request
// and previous refreshes.
type RefreshAttributes struct {
Username string
Subject string
DN string
Groups []string
AdditionalAttributes map[string]string
GrantedScopes []string
}
type UpstreamOIDCIdentityProviderI interface {
// GetName returns a name for this upstream provider. The controller watching the OIDCIdentityProviders will
// set this to be the Name of the CR from its metadata. Note that this is different from the DisplayName configured
// in each FederationDomain that uses this provider, so this name is for internal use only, not for interacting
// with clients. Clients should not expect to see this name or send this name.
GetName() string
// GetClientID returns the OAuth client ID registered with the upstream provider to be used in the authorization code flow.
GetClientID() string
// GetResourceUID returns the Kubernetes resource ID
GetResourceUID() types.UID
// GetAuthorizationURL returns the Authorization Endpoint fetched from discovery.
GetAuthorizationURL() *url.URL
// HasUserInfoURL returns whether there is a non-empty value for userinfo_endpoint fetched from discovery.
HasUserInfoURL() bool
// GetScopes returns the scopes to request in authorization (authcode or password grant) flow.
GetScopes() []string
// GetUsernameClaim returns the ID Token username claim name. May return empty string, in which case we
// will use some reasonable defaults.
GetUsernameClaim() string
// GetGroupsClaim returns the ID Token groups claim name. May return empty string, in which case we won't
// try to read groups from the upstream provider.
GetGroupsClaim() string
// AllowsPasswordGrant returns true if a client should be allowed to use the resource owner password credentials grant
// flow with this upstream provider. When false, it should not be allowed.
AllowsPasswordGrant() bool
// GetAdditionalAuthcodeParams returns additional params to be sent on authcode requests.
GetAdditionalAuthcodeParams() map[string]string
// GetAdditionalClaimMappings returns additional claims to be mapped from the upstream ID token.
GetAdditionalClaimMappings() map[string]string
// PasswordCredentialsGrantAndValidateTokens performs upstream OIDC resource owner password credentials grant and
// token validation. Returns the validated raw tokens as well as the parsed claims of the ID token.
PasswordCredentialsGrantAndValidateTokens(ctx context.Context, username, password string) (*oidctypes.Token, error)
// ExchangeAuthcodeAndValidateTokens performs upstream OIDC authorization code exchange and token validation.
// Returns the validated raw tokens as well as the parsed claims of the ID token.
ExchangeAuthcodeAndValidateTokens(
ctx context.Context,
authcode string,
pkceCodeVerifier pkce.Code,
expectedIDTokenNonce nonce.Nonce,
redirectURI string,
) (*oidctypes.Token, error)
// PerformRefresh will call the provider's token endpoint to perform a refresh grant. The provider may or may not
// return a new ID or refresh token in the response. If it returns an ID token, then use ValidateToken to
// validate the ID token.
PerformRefresh(ctx context.Context, refreshToken string) (*oauth2.Token, error)
// RevokeToken will attempt to revoke the given token, if the provider has a revocation endpoint.
// It may return an error wrapped by a RetryableRevocationError, which is an error indicating that it may
// be worth trying to revoke the same token again later. Any other error returned should be assumed to
// represent an error such that it is not worth retrying revocation later, even though revocation failed.
RevokeToken(ctx context.Context, token string, tokenType RevocableTokenType) error
// ValidateTokenAndMergeWithUserInfo will validate the ID token. It will also merge the claims from the userinfo endpoint response
// into the ID token's claims, if the provider offers the userinfo endpoint. It returns the validated/updated
// tokens, or an error.
ValidateTokenAndMergeWithUserInfo(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce, requireIDToken bool, requireUserInfo bool) (*oidctypes.Token, error)
}
type UpstreamLDAPIdentityProviderI interface {
// GetName returns a name for this upstream provider.
GetName() string
// GetURL returns a URL which uniquely identifies this LDAP provider, e.g. "ldaps://host.example.com:1234".
// This URL is not used for connecting to the provider, but rather is used for creating a globally unique user
// identifier by being combined with the user's UID, since user UIDs are only unique within one provider.
GetURL() *url.URL
// GetResourceUID returns the Kubernetes resource ID
GetResourceUID() types.UID
// UserAuthenticator adds an interface method for performing user authentication against the upstream LDAP provider.
authenticators.UserAuthenticator
// PerformRefresh performs a refresh against the upstream LDAP identity provider
PerformRefresh(ctx context.Context, storedRefreshAttributes RefreshAttributes) (groups []string, err error)
}

View File

@ -1,4 +1,4 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// Package token provides a handler for the OIDC token endpoint. // Package token provides a handler for the OIDC token endpoint.
@ -19,15 +19,17 @@ import (
oidcapi "go.pinniped.dev/generated/latest/apis/supervisor/oidc" oidcapi "go.pinniped.dev/generated/latest/apis/supervisor/oidc"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/idtransform"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/downstreamsession" "go.pinniped.dev/internal/oidc/downstreamsession"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/psession"
) )
func NewHandler( func NewHandler(
idpLister oidc.UpstreamIdentityProvidersLister, idpLister provider.FederationDomainIdentityProvidersListerI,
oauthHelper fosite.OAuth2Provider, oauthHelper fosite.OAuth2Provider,
) http.Handler { ) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
@ -95,7 +97,7 @@ func errUpstreamRefreshError() *fosite.RFC6749Error {
} }
} }
func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, providerCache oidc.UpstreamIdentityProvidersLister) error { func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester, idpLister provider.FederationDomainIdentityProvidersListerI) error {
session := accessRequest.GetSession().(*psession.PinnipedSession) session := accessRequest.GetSession().(*psession.PinnipedSession)
customSessionData := session.Custom customSessionData := session.Custom
@ -113,11 +115,11 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
switch customSessionData.ProviderType { switch customSessionData.ProviderType {
case psession.ProviderTypeOIDC: case psession.ProviderTypeOIDC:
return upstreamOIDCRefresh(ctx, session, providerCache, grantedScopes, clientID) return upstreamOIDCRefresh(ctx, session, idpLister, grantedScopes, clientID)
case psession.ProviderTypeLDAP: case psession.ProviderTypeLDAP:
return upstreamLDAPRefresh(ctx, providerCache, session, grantedScopes, clientID) return upstreamLDAPRefresh(ctx, idpLister, session, grantedScopes, clientID)
case psession.ProviderTypeActiveDirectory: case psession.ProviderTypeActiveDirectory:
return upstreamLDAPRefresh(ctx, providerCache, session, grantedScopes, clientID) return upstreamLDAPRefresh(ctx, idpLister, session, grantedScopes, clientID)
default: default:
return errorsx.WithStack(errMissingUpstreamSessionInternalError()) return errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
@ -126,7 +128,7 @@ func upstreamRefresh(ctx context.Context, accessRequest fosite.AccessRequester,
func upstreamOIDCRefresh( func upstreamOIDCRefresh(
ctx context.Context, ctx context.Context,
session *psession.PinnipedSession, session *psession.PinnipedSession,
providerCache oidc.UpstreamIdentityProvidersLister, idpLister provider.FederationDomainIdentityProvidersListerI,
grantedScopes []string, grantedScopes []string,
clientID string, clientID string,
) error { ) error {
@ -143,7 +145,7 @@ func upstreamOIDCRefresh(
return errorsx.WithStack(errMissingUpstreamSessionInternalError()) return errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
p, err := findOIDCProviderByNameAndValidateUID(s, providerCache) p, err := findOIDCProviderByNameAndValidateUID(s, idpLister)
if err != nil { if err != nil {
return err return err
} }
@ -153,7 +155,7 @@ func upstreamOIDCRefresh(
var tokens *oauth2.Token var tokens *oauth2.Token
if refreshTokenStored { if refreshTokenStored {
tokens, err = p.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken) tokens, err = p.Provider.PerformRefresh(ctx, s.OIDC.UpstreamRefreshToken)
if err != nil { if err != nil {
return errUpstreamRefreshError().WithHint( return errUpstreamRefreshError().WithHint(
"Upstream refresh failed.", "Upstream refresh failed.",
@ -174,7 +176,7 @@ func upstreamOIDCRefresh(
// way to check that the user's session was not revoked on the server. // way to check that the user's session was not revoked on the server.
// The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at // The spec is not 100% clear about whether an ID token from the refresh flow should include a nonce, and at
// least some providers do not include one, so we skip the nonce validation here (but not other validations). // least some providers do not include one, so we skip the nonce validation here (but not other validations).
validatedTokens, err := p.ValidateTokenAndMergeWithUserInfo(ctx, tokens, "", hasIDTok, accessTokenStored) validatedTokens, err := p.Provider.ValidateTokenAndMergeWithUserInfo(ctx, tokens, "", hasIDTok, accessTokenStored)
if err != nil { if err != nil {
return errUpstreamRefreshError().WithHintf( return errUpstreamRefreshError().WithHintf(
"Upstream refresh returned an invalid ID token or UserInfo response.").WithTrace(err). "Upstream refresh returned an invalid ID token or UserInfo response.").WithTrace(err).
@ -182,12 +184,22 @@ func upstreamOIDCRefresh(
} }
mergedClaims := validatedTokens.IDToken.Claims mergedClaims := validatedTokens.IDToken.Claims
// To the extent possible, check that the user's basic identity hasn't changed. oldTransformedUsername, err := getDownstreamUsernameFromPinnipedSession(session)
err = validateIdentityUnchangedSinceInitialLogin(mergedClaims, session, p.GetUsernameClaim()) if err != nil {
return err
}
oldTransformedGroups, err := getDownstreamGroupsFromPinnipedSession(session)
if err != nil { if err != nil {
return err return err
} }
// To the extent possible, check that the user's basic identity hasn't changed.
err = validateSubjectAndIssuerUnchangedSinceInitialLogin(mergedClaims, session)
if err != nil {
return err
}
var refreshedUntransformedGroups []string
groupsScope := slices.Contains(grantedScopes, oidcapi.ScopeGroups) groupsScope := slices.Contains(grantedScopes, oidcapi.ScopeGroups)
if groupsScope { //nolint:nestif if groupsScope { //nolint:nestif
// If possible, update the user's group memberships. The configured groups claim name (if there is one) may or // If possible, update the user's group memberships. The configured groups claim name (if there is one) may or
@ -197,26 +209,44 @@ func upstreamOIDCRefresh(
// If the claim is found, then use it to update the user's group membership in the session. // If the claim is found, then use it to update the user's group membership in the session.
// If the claim is not found, then we have no new information about groups, so skip updating the group membership // If the claim is not found, then we have no new information about groups, so skip updating the group membership
// and let any old groups memberships in the session remain. // and let any old groups memberships in the session remain.
refreshedGroups, err := downstreamsession.GetGroupsFromUpstreamIDToken(p, mergedClaims) refreshedUntransformedGroups, err = downstreamsession.GetGroupsFromUpstreamIDToken(p.Provider, mergedClaims)
if err != nil { if err != nil {
return errUpstreamRefreshError().WithHintf( return errUpstreamRefreshError().WithHintf(
"Upstream refresh error while extracting groups claim.").WithTrace(err). "Upstream refresh error while extracting groups claim.").WithTrace(err).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
} }
if refreshedGroups != nil {
oldGroups, err := getDownstreamGroupsFromPinnipedSession(session)
if err != nil {
return err
}
username, err := getDownstreamUsernameFromPinnipedSession(session)
if err != nil {
return err
}
warnIfGroupsChanged(ctx, oldGroups, refreshedGroups, username, clientID)
session.Fosite.Claims.Extra[oidcapi.IDTokenClaimGroups] = refreshedGroups
}
} }
// It's possible that a username wasn't returned by the upstream provider during refresh,
// but if it is, verify that the transformed version of it hasn't changed.
refreshedUntransformedUsername, hasRefreshedUntransformedUsername := getString(mergedClaims, p.Provider.GetUsernameClaim())
if !hasRefreshedUntransformedUsername {
// If we could not get a new username, then we still need the untransformed username to be able to
// run the transformations again, so fetch the original untransformed username from the session.
refreshedUntransformedUsername = s.UpstreamUsername
}
if refreshedUntransformedGroups == nil {
// If we could not get a new list of groups, then we still need the untransformed groups list to be able to
// run the transformations again, so fetch the original untransformed groups list from the session.
refreshedUntransformedGroups = s.UpstreamGroups
}
transformationResult, err := transformRefreshedIdentity(ctx,
p.Transforms,
oldTransformedUsername,
refreshedUntransformedUsername,
refreshedUntransformedGroups,
s.ProviderName,
s.ProviderType,
)
if err != nil {
return err
}
warnIfGroupsChanged(ctx, oldTransformedGroups, transformationResult.Groups, transformationResult.Username, clientID)
session.Fosite.Claims.Extra[oidcapi.IDTokenClaimGroups] = refreshedUntransformedGroups
// Upstream refresh may or may not return a new refresh token. If we got a new refresh token, then update it in // Upstream refresh may or may not return a new refresh token. If we got a new refresh token, then update it in
// the user's session. If we did not get a new refresh token, then keep the old one in the session by avoiding // the user's session. If we did not get a new refresh token, then keep the old one in the session by avoiding
// overwriting the old one. // overwriting the old one.
@ -238,7 +268,7 @@ func diffSortedGroups(oldGroups, newGroups []string) ([]string, []string) {
return added.List(), removed.List() return added.List(), removed.List()
} }
func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interface{}, session *psession.PinnipedSession, usernameClaimName string) error { func validateSubjectAndIssuerUnchangedSinceInitialLogin(mergedClaims map[string]interface{}, session *psession.PinnipedSession) error {
s := session.Custom s := session.Custom
// If we have any claims at all, we better have a subject, and it better match the previous value. // If we have any claims at all, we better have a subject, and it better match the previous value.
@ -260,19 +290,6 @@ func validateIdentityUnchangedSinceInitialLogin(mergedClaims map[string]interfac
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
} }
newUsername, hasUsername := getString(mergedClaims, usernameClaimName)
oldUsername, err := getDownstreamUsernameFromPinnipedSession(session)
if err != nil {
return err
}
// It's possible that a username wasn't returned by the upstream provider during refresh,
// but if it is, verify that it hasn't changed.
if hasUsername && oldUsername != newUsername {
return errUpstreamRefreshError().WithHintf(
"Upstream refresh failed.").WithTrace(errors.New("username in upstream refresh does not match previous value")).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
}
newIssuer, hasIssuer := getString(mergedClaims, oidcapi.IDTokenClaimIssuer) newIssuer, hasIssuer := getString(mergedClaims, oidcapi.IDTokenClaimIssuer)
// It's possible that an issuer wasn't returned by the upstream provider during refresh, // It's possible that an issuer wasn't returned by the upstream provider during refresh,
// but if it is, verify that it hasn't changed. // but if it is, verify that it hasn't changed.
@ -292,11 +309,11 @@ func getString(m map[string]interface{}, key string) (string, bool) {
func findOIDCProviderByNameAndValidateUID( func findOIDCProviderByNameAndValidateUID(
s *psession.CustomSessionData, s *psession.CustomSessionData,
providerCache oidc.UpstreamIdentityProvidersLister, idpLister provider.FederationDomainIdentityProvidersListerI,
) (provider.UpstreamOIDCIdentityProviderI, error) { ) (*provider.FederationDomainResolvedOIDCIdentityProvider, error) {
for _, p := range providerCache.GetOIDCIdentityProviders() { for _, p := range idpLister.GetOIDCIdentityProviders() {
if p.GetName() == s.ProviderName { if p.Provider.GetName() == s.ProviderName {
if p.GetResourceUID() != s.ProviderUID { if p.Provider.GetResourceUID() != s.ProviderUID {
return nil, errorsx.WithStack(errUpstreamRefreshError().WithHint( return nil, errorsx.WithStack(errUpstreamRefreshError().WithHint(
"Provider from upstream session data has changed its resource UID since authentication.")) "Provider from upstream session data has changed its resource UID since authentication."))
} }
@ -310,27 +327,25 @@ func findOIDCProviderByNameAndValidateUID(
func upstreamLDAPRefresh( func upstreamLDAPRefresh(
ctx context.Context, ctx context.Context,
providerCache oidc.UpstreamIdentityProvidersLister, idpLister provider.FederationDomainIdentityProvidersListerI,
session *psession.PinnipedSession, session *psession.PinnipedSession,
grantedScopes []string, grantedScopes []string,
clientID string, clientID string,
) error { ) error {
username, err := getDownstreamUsernameFromPinnipedSession(session) oldTransformedUsername, err := getDownstreamUsernameFromPinnipedSession(session)
if err != nil { if err != nil {
return err return err
} }
subject := session.Fosite.Claims.Subject subject := session.Fosite.Claims.Subject
var oldGroups []string var oldTransformedGroups []string
if slices.Contains(grantedScopes, oidcapi.ScopeGroups) { if slices.Contains(grantedScopes, oidcapi.ScopeGroups) {
oldGroups, err = getDownstreamGroupsFromPinnipedSession(session) oldTransformedGroups, err = getDownstreamGroupsFromPinnipedSession(session)
if err != nil { if err != nil {
return err return err
} }
} }
s := session.Custom s := session.Custom
// if you have neither a valid ldap session config nor a valid active directory session config
validLDAP := s.ProviderType == psession.ProviderTypeLDAP && s.LDAP != nil && s.LDAP.UserDN != "" validLDAP := s.ProviderType == psession.ProviderTypeLDAP && s.LDAP != nil && s.LDAP.UserDN != ""
validAD := s.ProviderType == psession.ProviderTypeActiveDirectory && s.ActiveDirectory != nil && s.ActiveDirectory.UserDN != "" validAD := s.ProviderType == psession.ProviderTypeActiveDirectory && s.ActiveDirectory != nil && s.ActiveDirectory.UserDN != ""
if !(validLDAP || validAD) { if !(validLDAP || validAD) {
@ -344,20 +359,19 @@ func upstreamLDAPRefresh(
additionalAttributes = s.ActiveDirectory.ExtraRefreshAttributes additionalAttributes = s.ActiveDirectory.ExtraRefreshAttributes
} }
// get ldap/ad provider out of cache p, dn, err := findLDAPProviderByNameAndValidateUID(s, idpLister)
p, dn, err := findLDAPProviderByNameAndValidateUID(s, providerCache)
if err != nil { if err != nil {
return err return err
} }
if session.IDTokenClaims().AuthTime.IsZero() { if session.IDTokenClaims().AuthTime.IsZero() {
return errorsx.WithStack(errMissingUpstreamSessionInternalError()) return errorsx.WithStack(errMissingUpstreamSessionInternalError())
} }
// run PerformRefresh
groups, err := p.PerformRefresh(ctx, provider.RefreshAttributes{ refreshedUntransformedGroups, err := p.Provider.PerformRefresh(ctx, upstreamprovider.RefreshAttributes{
Username: username, Username: s.UpstreamUsername,
Subject: subject, Subject: subject,
DN: dn, DN: dn,
Groups: oldGroups, Groups: s.UpstreamGroups,
AdditionalAttributes: additionalAttributes, AdditionalAttributes: additionalAttributes,
GrantedScopes: grantedScopes, GrantedScopes: grantedScopes,
}) })
@ -366,33 +380,79 @@ func upstreamLDAPRefresh(
"Upstream refresh failed.").WithTrace(err). "Upstream refresh failed.").WithTrace(err).
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)
} }
transformationResult, err := transformRefreshedIdentity(ctx,
p.Transforms,
oldTransformedUsername,
s.UpstreamUsername,
refreshedUntransformedGroups,
s.ProviderName,
s.ProviderType,
)
if err != nil {
return err
}
groupsScope := slices.Contains(grantedScopes, oidcapi.ScopeGroups) groupsScope := slices.Contains(grantedScopes, oidcapi.ScopeGroups)
if groupsScope { if groupsScope {
warnIfGroupsChanged(ctx, oldGroups, groups, username, clientID) warnIfGroupsChanged(ctx, oldTransformedGroups, transformationResult.Groups, transformationResult.Username, clientID)
// Replace the old value with the new value. // Replace the old value with the new value.
session.Fosite.Claims.Extra[oidcapi.IDTokenClaimGroups] = groups session.Fosite.Claims.Extra[oidcapi.IDTokenClaimGroups] = transformationResult.Groups
} }
return nil return nil
} }
func transformRefreshedIdentity(
ctx context.Context,
transforms *idtransform.TransformationPipeline,
oldTransformedUsername string,
upstreamUsername string,
upstreamGroups []string,
providerName string,
providerType psession.ProviderType,
) (*idtransform.TransformationResult, error) {
transformationResult, err := transforms.Evaluate(ctx, upstreamUsername, upstreamGroups)
if err != nil {
return nil, errUpstreamRefreshError().WithHintf(
"Upstream refresh error while applying configured identity transformations.").
WithTrace(err).
WithDebugf("provider name: %q, provider type: %q", providerName, providerType)
}
if !transformationResult.AuthenticationAllowed {
return nil, errUpstreamRefreshError().WithHintf(
"Upstream refresh rejected by configured identity policy: %s.", transformationResult.RejectedAuthenticationMessage).
WithDebugf("provider name: %q, provider type: %q", providerName, providerType)
}
if oldTransformedUsername != transformationResult.Username {
return nil, errUpstreamRefreshError().WithHintf(
"Upstream refresh failed.").
WithTrace(errors.New("username in upstream refresh does not match previous value")).
WithDebugf("provider name: %q, provider type: %q", providerName, providerType)
}
return transformationResult, nil
}
func findLDAPProviderByNameAndValidateUID( func findLDAPProviderByNameAndValidateUID(
s *psession.CustomSessionData, s *psession.CustomSessionData,
providerCache oidc.UpstreamIdentityProvidersLister, idpLister provider.FederationDomainIdentityProvidersListerI,
) (provider.UpstreamLDAPIdentityProviderI, string, error) { ) (*provider.FederationDomainResolvedLDAPIdentityProvider, string, error) {
var providers []provider.UpstreamLDAPIdentityProviderI var providers []*provider.FederationDomainResolvedLDAPIdentityProvider
var dn string var dn string
if s.ProviderType == psession.ProviderTypeLDAP { if s.ProviderType == psession.ProviderTypeLDAP {
providers = providerCache.GetLDAPIdentityProviders() providers = idpLister.GetLDAPIdentityProviders()
dn = s.LDAP.UserDN dn = s.LDAP.UserDN
} else if s.ProviderType == psession.ProviderTypeActiveDirectory { } else if s.ProviderType == psession.ProviderTypeActiveDirectory {
providers = providerCache.GetActiveDirectoryIdentityProviders() providers = idpLister.GetActiveDirectoryIdentityProviders()
dn = s.ActiveDirectory.UserDN dn = s.ActiveDirectory.UserDN
} }
for _, p := range providers { for _, p := range providers {
if p.GetName() == s.ProviderName { if p.Provider.GetName() == s.ProviderName {
if p.GetResourceUID() != s.ProviderUID { if p.Provider.GetResourceUID() != s.ProviderUID {
return nil, "", errorsx.WithStack(errUpstreamRefreshError().WithHint( return nil, "", errorsx.WithStack(errUpstreamRefreshError().WithHint(
"Provider from upstream session data has changed its resource UID since authentication."). "Provider from upstream session data has changed its resource UID since authentication.").
WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType)) WithDebugf("provider name: %q, provider type: %q", s.ProviderName, s.ProviderType))

View File

@ -1,4 +1,4 @@
// Copyright 2021-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2021-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package psession package psession
@ -32,6 +32,18 @@ type CustomSessionData struct {
// all users must have a username. // all users must have a username.
Username string `json:"username"` Username string `json:"username"`
// UpstreamUsername is the username from the upstream identity provider during the user's initial login before
// identity transformations were applied. We store this so that we can still reapply identity transformations
// during refresh flows even when an upstream OIDC provider does not return the username again during the upstream
// refresh, and so we can validate that same untransformed username was found during an LDAP refresh.
UpstreamUsername string `json:"upstreamUsername"`
// UpstreamGroups is the groups list from the upstream identity provider during the user's initial login before
// identity transformations were applied. We store this so that we can still reapply identity transformations
// during refresh flows even when an OIDC provider does not return the groups again during the upstream
// refresh, and when the LDAP search was configured to skip group refreshes.
UpstreamGroups []string `json:"upstreamGroups"`
// The Kubernetes resource UID of the identity provider CRD for the upstream IDP used to start this session. // The Kubernetes resource UID of the identity provider CRD for the upstream IDP used to start this session.
// This should be validated again upon downstream refresh to make sure that we are not refreshing against // This should be validated again upon downstream refresh to make sure that we are not refreshing against
// a different identity provider CRD which just happens to have the same name. // a different identity provider CRD which just happens to have the same name.
@ -41,11 +53,12 @@ type CustomSessionData struct {
// The Kubernetes resource name of the identity provider CRD for the upstream IDP used to start this session. // The Kubernetes resource name of the identity provider CRD for the upstream IDP used to start this session.
// Used during a downstream refresh to decide which upstream to refresh. // Used during a downstream refresh to decide which upstream to refresh.
// Also used to decide which of the pointer types below should be used. // Also used by the session storage garbage collector to decide which upstream to use for token revocation.
ProviderName string `json:"providerName"` ProviderName string `json:"providerName"`
// The type of the identity provider for the upstream IDP used to start this session. // The type of the identity provider for the upstream IDP used to start this session.
// Used during a downstream refresh to decide which upstream to refresh. // Used during a downstream refresh to decide which upstream to refresh.
// Also used to decide which of the pointer types below should be used.
ProviderType ProviderType `json:"providerType"` ProviderType ProviderType `json:"providerType"`
// Warnings that were encountered at some point during login that should be emitted to the client. // Warnings that were encountered at some point during login that should be emitted to the client.
@ -55,8 +68,10 @@ type CustomSessionData struct {
// Only used when ProviderType == "oidc". // Only used when ProviderType == "oidc".
OIDC *OIDCSessionData `json:"oidc,omitempty"` OIDC *OIDCSessionData `json:"oidc,omitempty"`
// Only used when ProviderType == "ldap".
LDAP *LDAPSessionData `json:"ldap,omitempty"` LDAP *LDAPSessionData `json:"ldap,omitempty"`
// Only used when ProviderType == "activedirectory".
ActiveDirectory *ActiveDirectorySessionData `json:"activedirectory,omitempty"` ActiveDirectory *ActiveDirectorySessionData `json:"activedirectory,omitempty"`
} }

View File

@ -169,6 +169,9 @@ func prepareControllers(
clock.RealClock{}, clock.RealClock{},
pinnipedClient, pinnipedClient,
federationDomainInformer, federationDomainInformer,
pinnipedInformers.IDP().V1alpha1().OIDCIdentityProviders(),
pinnipedInformers.IDP().V1alpha1().LDAPIdentityProviders(),
pinnipedInformers.IDP().V1alpha1().ActiveDirectoryIdentityProviders(),
controllerlib.WithInformer, controllerlib.WithInformer,
), ),
singletonWorker, singletonWorker,

View File

@ -35,6 +35,7 @@ import (
pkce2 "go.pinniped.dev/internal/fositestorage/pkce" pkce2 "go.pinniped.dev/internal/fositestorage/pkce"
"go.pinniped.dev/internal/fositestoragei" "go.pinniped.dev/internal/fositestoragei"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/psession" "go.pinniped.dev/internal/psession"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
@ -77,7 +78,7 @@ type PerformRefreshArgs struct {
type RevokeTokenArgs struct { type RevokeTokenArgs struct {
Ctx context.Context Ctx context.Context
Token string Token string
TokenType provider.RevocableTokenType TokenType upstreamprovider.RevocableTokenType
} }
// ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to // ValidateTokenAndMergeWithUserInfoArgs is used to spy on calls to
@ -93,7 +94,7 @@ type ValidateTokenAndMergeWithUserInfoArgs struct {
type ValidateRefreshArgs struct { type ValidateRefreshArgs struct {
Ctx context.Context Ctx context.Context
Tok *oauth2.Token Tok *oauth2.Token
StoredAttributes provider.RefreshAttributes StoredAttributes upstreamprovider.RefreshAttributes
} }
type TestUpstreamLDAPIdentityProvider struct { type TestUpstreamLDAPIdentityProvider struct {
@ -107,7 +108,7 @@ type TestUpstreamLDAPIdentityProvider struct {
PerformRefreshGroups []string PerformRefreshGroups []string
} }
var _ provider.UpstreamLDAPIdentityProviderI = &TestUpstreamLDAPIdentityProvider{} var _ upstreamprovider.UpstreamLDAPIdentityProviderI = &TestUpstreamLDAPIdentityProvider{}
func (u *TestUpstreamLDAPIdentityProvider) GetResourceUID() types.UID { func (u *TestUpstreamLDAPIdentityProvider) GetResourceUID() types.UID {
return u.ResourceUID return u.ResourceUID
@ -125,7 +126,7 @@ func (u *TestUpstreamLDAPIdentityProvider) GetURL() *url.URL {
return u.URL return u.URL
} }
func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.RefreshAttributes) ([]string, error) { func (u *TestUpstreamLDAPIdentityProvider) PerformRefresh(ctx context.Context, storedRefreshAttributes upstreamprovider.RefreshAttributes) ([]string, error) {
if u.performRefreshArgs == nil { if u.performRefreshArgs == nil {
u.performRefreshArgs = make([]*PerformRefreshArgs, 0) u.performRefreshArgs = make([]*PerformRefreshArgs, 0)
} }
@ -182,7 +183,7 @@ type TestUpstreamOIDCIdentityProvider struct {
PerformRefreshFunc func(ctx context.Context, refreshToken string) (*oauth2.Token, error) PerformRefreshFunc func(ctx context.Context, refreshToken string) (*oauth2.Token, error)
RevokeTokenFunc func(ctx context.Context, refreshToken string, tokenType provider.RevocableTokenType) error RevokeTokenFunc func(ctx context.Context, refreshToken string, tokenType upstreamprovider.RevocableTokenType) error
ValidateTokenAndMergeWithUserInfoFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) ValidateTokenAndMergeWithUserInfoFunc func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error)
@ -198,7 +199,7 @@ type TestUpstreamOIDCIdentityProvider struct {
validateTokenAndMergeWithUserInfoArgs []*ValidateTokenAndMergeWithUserInfoArgs validateTokenAndMergeWithUserInfoArgs []*ValidateTokenAndMergeWithUserInfoArgs
} }
var _ provider.UpstreamOIDCIdentityProviderI = &TestUpstreamOIDCIdentityProvider{} var _ upstreamprovider.UpstreamOIDCIdentityProviderI = &TestUpstreamOIDCIdentityProvider{}
func (u *TestUpstreamOIDCIdentityProvider) GetResourceUID() types.UID { func (u *TestUpstreamOIDCIdentityProvider) GetResourceUID() types.UID {
return u.ResourceUID return u.ResourceUID
@ -302,7 +303,7 @@ func (u *TestUpstreamOIDCIdentityProvider) PerformRefresh(ctx context.Context, r
return u.PerformRefreshFunc(ctx, refreshToken) return u.PerformRefreshFunc(ctx, refreshToken)
} }
func (u *TestUpstreamOIDCIdentityProvider) RevokeToken(ctx context.Context, token string, tokenType provider.RevocableTokenType) error { func (u *TestUpstreamOIDCIdentityProvider) RevokeToken(ctx context.Context, token string, tokenType upstreamprovider.RevocableTokenType) error {
if u.revokeTokenArgs == nil { if u.revokeTokenArgs == nil {
u.revokeTokenArgs = make([]*RevokeTokenArgs, 0) u.revokeTokenArgs = make([]*RevokeTokenArgs, 0)
} }
@ -387,21 +388,21 @@ func (b *UpstreamIDPListerBuilder) WithActiveDirectory(upstreamActiveDirectoryId
func (b *UpstreamIDPListerBuilder) Build() provider.DynamicUpstreamIDPProvider { func (b *UpstreamIDPListerBuilder) Build() provider.DynamicUpstreamIDPProvider {
idpProvider := provider.NewDynamicUpstreamIDPProvider() idpProvider := provider.NewDynamicUpstreamIDPProvider()
oidcUpstreams := make([]provider.UpstreamOIDCIdentityProviderI, len(b.upstreamOIDCIdentityProviders)) oidcUpstreams := make([]upstreamprovider.UpstreamOIDCIdentityProviderI, len(b.upstreamOIDCIdentityProviders))
for i := range b.upstreamOIDCIdentityProviders { for i := range b.upstreamOIDCIdentityProviders {
oidcUpstreams[i] = provider.UpstreamOIDCIdentityProviderI(b.upstreamOIDCIdentityProviders[i]) oidcUpstreams[i] = upstreamprovider.UpstreamOIDCIdentityProviderI(b.upstreamOIDCIdentityProviders[i])
} }
idpProvider.SetOIDCIdentityProviders(oidcUpstreams) idpProvider.SetOIDCIdentityProviders(oidcUpstreams)
ldapUpstreams := make([]provider.UpstreamLDAPIdentityProviderI, len(b.upstreamLDAPIdentityProviders)) ldapUpstreams := make([]upstreamprovider.UpstreamLDAPIdentityProviderI, len(b.upstreamLDAPIdentityProviders))
for i := range b.upstreamLDAPIdentityProviders { for i := range b.upstreamLDAPIdentityProviders {
ldapUpstreams[i] = provider.UpstreamLDAPIdentityProviderI(b.upstreamLDAPIdentityProviders[i]) ldapUpstreams[i] = upstreamprovider.UpstreamLDAPIdentityProviderI(b.upstreamLDAPIdentityProviders[i])
} }
idpProvider.SetLDAPIdentityProviders(ldapUpstreams) idpProvider.SetLDAPIdentityProviders(ldapUpstreams)
adUpstreams := make([]provider.UpstreamLDAPIdentityProviderI, len(b.upstreamActiveDirectoryIdentityProviders)) adUpstreams := make([]upstreamprovider.UpstreamLDAPIdentityProviderI, len(b.upstreamActiveDirectoryIdentityProviders))
for i := range b.upstreamActiveDirectoryIdentityProviders { for i := range b.upstreamActiveDirectoryIdentityProviders {
adUpstreams[i] = provider.UpstreamLDAPIdentityProviderI(b.upstreamActiveDirectoryIdentityProviders[i]) adUpstreams[i] = upstreamprovider.UpstreamLDAPIdentityProviderI(b.upstreamActiveDirectoryIdentityProviders[i])
} }
idpProvider.SetActiveDirectoryIdentityProviders(adUpstreams) idpProvider.SetActiveDirectoryIdentityProviders(adUpstreams)
@ -822,7 +823,7 @@ func (u *TestUpstreamOIDCIdentityProviderBuilder) Build() *TestUpstreamOIDCIdent
} }
return u.refreshedTokens, nil return u.refreshedTokens, nil
}, },
RevokeTokenFunc: func(ctx context.Context, refreshToken string, tokenType provider.RevocableTokenType) error { RevokeTokenFunc: func(ctx context.Context, refreshToken string, tokenType upstreamprovider.RevocableTokenType) error {
return u.revokeTokenErr return u.revokeTokenErr
}, },
ValidateTokenAndMergeWithUserInfoFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) { ValidateTokenAndMergeWithUserInfoFunc: func(ctx context.Context, tok *oauth2.Token, expectedIDTokenNonce nonce.Nonce) (*oidctypes.Token, error) {

View File

@ -28,7 +28,7 @@ import (
"go.pinniped.dev/internal/crypto/ptls" "go.pinniped.dev/internal/crypto/ptls"
"go.pinniped.dev/internal/endpointaddr" "go.pinniped.dev/internal/endpointaddr"
"go.pinniped.dev/internal/oidc/downstreamsession" "go.pinniped.dev/internal/oidc/downstreamsession"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
) )
@ -120,7 +120,7 @@ type ProviderConfig struct {
GroupAttributeParsingOverrides map[string]func(*ldap.Entry) (string, error) GroupAttributeParsingOverrides map[string]func(*ldap.Entry) (string, error)
// RefreshAttributeChecks are extra checks that attributes in a refresh response are as expected. // RefreshAttributeChecks are extra checks that attributes in a refresh response are as expected.
RefreshAttributeChecks map[string]func(*ldap.Entry, provider.RefreshAttributes) error RefreshAttributeChecks map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error
} }
// UserSearchConfig contains information about how to search for users in the upstream LDAP IDP. // UserSearchConfig contains information about how to search for users in the upstream LDAP IDP.
@ -167,7 +167,7 @@ type Provider struct {
c ProviderConfig c ProviderConfig
} }
var _ provider.UpstreamLDAPIdentityProviderI = &Provider{} var _ upstreamprovider.UpstreamLDAPIdentityProviderI = &Provider{}
var _ authenticators.UserAuthenticator = &Provider{} var _ authenticators.UserAuthenticator = &Provider{}
// New creates a Provider. The config is not a pointer to ensure that a copy of the config is created, // New creates a Provider. The config is not a pointer to ensure that a copy of the config is created,
@ -188,7 +188,7 @@ func closeAndLogError(conn Conn, doingWhat string) {
} }
} }
func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes provider.RefreshAttributes) ([]string, error) { func (p *Provider) PerformRefresh(ctx context.Context, storedRefreshAttributes upstreamprovider.RefreshAttributes) ([]string, error) {
t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()}) t := trace.FromContext(ctx).Nest("slow ldap refresh attempt", trace.Field{Key: "providerName", Value: p.GetName()})
defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches defer t.LogIfLong(500 * time.Millisecond) // to help users debug slow LDAP searches
userDN := storedRefreshAttributes.DN userDN := storedRefreshAttributes.DN

View File

@ -26,7 +26,7 @@ import (
"go.pinniped.dev/internal/crypto/ptls" "go.pinniped.dev/internal/crypto/ptls"
"go.pinniped.dev/internal/endpointaddr" "go.pinniped.dev/internal/endpointaddr"
"go.pinniped.dev/internal/mocks/mockldapconn" "go.pinniped.dev/internal/mocks/mockldapconn"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/internal/testutil/tlsassertions" "go.pinniped.dev/internal/testutil/tlsassertions"
"go.pinniped.dev/internal/testutil/tlsserver" "go.pinniped.dev/internal/testutil/tlsserver"
@ -661,8 +661,8 @@ func TestEndUserAuthentication(t *testing.T) {
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
providerConfig: providerConfig(func(p *ProviderConfig) { providerConfig: providerConfig(func(p *ProviderConfig) {
p.RefreshAttributeChecks = map[string]func(entry *ldap.Entry, attributes provider.RefreshAttributes) error{ p.RefreshAttributeChecks = map[string]func(entry *ldap.Entry, attributes upstreamprovider.RefreshAttributes) error{
"some-attribute-to-check-during-refresh": func(entry *ldap.Entry, attributes provider.RefreshAttributes) error { "some-attribute-to-check-during-refresh": func(entry *ldap.Entry, attributes upstreamprovider.RefreshAttributes) error {
return nil return nil
}, },
} }
@ -699,8 +699,8 @@ func TestEndUserAuthentication(t *testing.T) {
username: testUpstreamUsername, username: testUpstreamUsername,
password: testUpstreamPassword, password: testUpstreamPassword,
providerConfig: providerConfig(func(p *ProviderConfig) { providerConfig: providerConfig(func(p *ProviderConfig) {
p.RefreshAttributeChecks = map[string]func(entry *ldap.Entry, attributes provider.RefreshAttributes) error{ p.RefreshAttributeChecks = map[string]func(entry *ldap.Entry, attributes upstreamprovider.RefreshAttributes) error{
"some-attribute-to-check-during-refresh": func(entry *ldap.Entry, attributes provider.RefreshAttributes) error { "some-attribute-to-check-during-refresh": func(entry *ldap.Entry, attributes upstreamprovider.RefreshAttributes) error {
return nil return nil
}, },
} }
@ -1575,8 +1575,8 @@ func TestUpstreamRefresh(t *testing.T) {
Filter: testGroupSearchFilter, Filter: testGroupSearchFilter,
GroupNameAttribute: testGroupSearchGroupNameAttribute, GroupNameAttribute: testGroupSearchGroupNameAttribute,
}, },
RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.RefreshAttributes) error{ RefreshAttributeChecks: map[string]func(*ldap.Entry, upstreamprovider.RefreshAttributes) error{
pwdLastSetAttribute: func(*ldap.Entry, provider.RefreshAttributes) error { return nil }, pwdLastSetAttribute: func(*ldap.Entry, upstreamprovider.RefreshAttributes) error { return nil },
}, },
} }
if editFunc != nil { if editFunc != nil {
@ -2280,7 +2280,7 @@ func TestUpstreamRefresh(t *testing.T) {
initialPwdLastSetEncoded := base64.RawURLEncoding.EncodeToString([]byte("132801740800000000")) initialPwdLastSetEncoded := base64.RawURLEncoding.EncodeToString([]byte("132801740800000000"))
ldapProvider := New(*tt.providerConfig) ldapProvider := New(*tt.providerConfig)
subject := "ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU" subject := "ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU"
groups, err := ldapProvider.PerformRefresh(context.Background(), provider.RefreshAttributes{ groups, err := ldapProvider.PerformRefresh(context.Background(), upstreamprovider.RefreshAttributes{
Username: testUserSearchResultUsernameAttributeValue, Username: testUserSearchResultUsernameAttributeValue,
Subject: subject, Subject: subject,
DN: tt.refreshUserDN, DN: tt.refreshUserDN,

View File

@ -23,13 +23,14 @@ import (
oidcapi "go.pinniped.dev/generated/latest/apis/supervisor/oidc" oidcapi "go.pinniped.dev/generated/latest/apis/supervisor/oidc"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
) )
func New(config *oauth2.Config, provider *coreosoidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { func New(config *oauth2.Config, provider *coreosoidc.Provider, client *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
return &ProviderConfig{Config: config, Provider: provider, Client: client} return &ProviderConfig{Config: config, Provider: provider, Client: client}
} }
@ -52,7 +53,7 @@ type ProviderConfig struct {
} }
} }
var _ provider.UpstreamOIDCIdentityProviderI = (*ProviderConfig)(nil) var _ upstreamprovider.UpstreamOIDCIdentityProviderI = (*ProviderConfig)(nil)
func (p *ProviderConfig) GetResourceUID() types.UID { func (p *ProviderConfig) GetResourceUID() types.UID {
return p.ResourceUID return p.ResourceUID
@ -160,7 +161,7 @@ func (p *ProviderConfig) PerformRefresh(ctx context.Context, refreshToken string
// It may return an error wrapped by a RetryableRevocationError, which is an error indicating that it may // It may return an error wrapped by a RetryableRevocationError, which is an error indicating that it may
// be worth trying to revoke the same token again later. Any other error returned should be assumed to // be worth trying to revoke the same token again later. Any other error returned should be assumed to
// represent an error such that it is not worth retrying revocation later, even though revocation failed. // represent an error such that it is not worth retrying revocation later, even though revocation failed.
func (p *ProviderConfig) RevokeToken(ctx context.Context, token string, tokenType provider.RevocableTokenType) error { func (p *ProviderConfig) RevokeToken(ctx context.Context, token string, tokenType upstreamprovider.RevocableTokenType) error {
if p.RevocationURL == nil { if p.RevocationURL == nil {
plog.Trace("RevokeToken() was called but upstream provider has no available revocation endpoint", plog.Trace("RevokeToken() was called but upstream provider has no available revocation endpoint",
"providerName", p.Name, "providerName", p.Name,
@ -188,7 +189,7 @@ func (p *ProviderConfig) RevokeToken(ctx context.Context, token string, tokenTyp
func (p *ProviderConfig) tryRevokeToken( func (p *ProviderConfig) tryRevokeToken(
ctx context.Context, ctx context.Context,
token string, token string,
tokenType provider.RevocableTokenType, tokenType upstreamprovider.RevocableTokenType,
useBasicAuth bool, useBasicAuth bool,
) (tryAnotherClientAuthMethod bool, err error) { ) (tryAnotherClientAuthMethod bool, err error) {
clientID := p.Config.ClientID clientID := p.Config.ClientID

View File

@ -25,6 +25,7 @@ import (
"go.pinniped.dev/internal/mocks/mockkeyset" "go.pinniped.dev/internal/mocks/mockkeyset"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/oidctypes"
@ -484,7 +485,7 @@ func TestProviderConfig(t *testing.T) {
t.Run("RevokeToken", func(t *testing.T) { t.Run("RevokeToken", func(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
tokenType provider.RevocableTokenType tokenType upstreamprovider.RevocableTokenType
nilRevocationURL bool nilRevocationURL bool
unreachableServer bool unreachableServer bool
returnStatusCodes []int returnStatusCodes []int
@ -496,33 +497,33 @@ func TestProviderConfig(t *testing.T) {
}{ }{
{ {
name: "success without calling the server when there is no revocation URL set for refresh token", name: "success without calling the server when there is no revocation URL set for refresh token",
tokenType: provider.RefreshTokenType, tokenType: upstreamprovider.RefreshTokenType,
nilRevocationURL: true, nilRevocationURL: true,
wantNumRequests: 0, wantNumRequests: 0,
}, },
{ {
name: "success without calling the server when there is no revocation URL set for access token", name: "success without calling the server when there is no revocation URL set for access token",
tokenType: provider.AccessTokenType, tokenType: upstreamprovider.AccessTokenType,
nilRevocationURL: true, nilRevocationURL: true,
wantNumRequests: 0, wantNumRequests: 0,
}, },
{ {
name: "success when the server returns 200 OK on the first call for refresh token", name: "success when the server returns 200 OK on the first call for refresh token",
tokenType: provider.RefreshTokenType, tokenType: upstreamprovider.RefreshTokenType,
returnStatusCodes: []int{http.StatusOK}, returnStatusCodes: []int{http.StatusOK},
wantNumRequests: 1, wantNumRequests: 1,
wantTokenTypeHint: "refresh_token", wantTokenTypeHint: "refresh_token",
}, },
{ {
name: "success when the server returns 200 OK on the first call for access token", name: "success when the server returns 200 OK on the first call for access token",
tokenType: provider.AccessTokenType, tokenType: upstreamprovider.AccessTokenType,
returnStatusCodes: []int{http.StatusOK}, returnStatusCodes: []int{http.StatusOK},
wantNumRequests: 1, wantNumRequests: 1,
wantTokenTypeHint: "access_token", wantTokenTypeHint: "access_token",
}, },
{ {
name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call for refresh token", name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call for refresh token",
tokenType: provider.RefreshTokenType, tokenType: upstreamprovider.RefreshTokenType,
returnStatusCodes: []int{http.StatusBadRequest, http.StatusOK}, returnStatusCodes: []int{http.StatusBadRequest, http.StatusOK},
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 defines this as the error for client auth failure // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 defines this as the error for client auth failure
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`}, returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`},
@ -531,7 +532,7 @@ func TestProviderConfig(t *testing.T) {
}, },
{ {
name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call for access token", name: "success when the server returns 400 Bad Request on the first call due to client auth, then 200 OK on second call for access token",
tokenType: provider.AccessTokenType, tokenType: upstreamprovider.AccessTokenType,
returnStatusCodes: []int{http.StatusBadRequest, http.StatusOK}, returnStatusCodes: []int{http.StatusBadRequest, http.StatusOK},
// https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 defines this as the error for client auth failure // https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 defines this as the error for client auth failure
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`}, returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`},
@ -540,7 +541,7 @@ func TestProviderConfig(t *testing.T) {
}, },
{ {
name: "error when the server returns 400 Bad Request on the first call due to client auth, then any 400 error on second call", name: "error when the server returns 400 Bad Request on the first call due to client auth, then any 400 error on second call",
tokenType: provider.RefreshTokenType, tokenType: upstreamprovider.RefreshTokenType,
returnStatusCodes: []int{http.StatusBadRequest, http.StatusBadRequest}, returnStatusCodes: []int{http.StatusBadRequest, http.StatusBadRequest},
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, `{ "error":"anything", "error_description":"unhappy" }`}, returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, `{ "error":"anything", "error_description":"unhappy" }`},
wantErr: testutil.WantExactErrorString(`server responded with status 400 with body: { "error":"anything", "error_description":"unhappy" }`), wantErr: testutil.WantExactErrorString(`server responded with status 400 with body: { "error":"anything", "error_description":"unhappy" }`),
@ -550,7 +551,7 @@ func TestProviderConfig(t *testing.T) {
}, },
{ {
name: "error when the server returns 400 Bad Request with bad JSON body on the first call", name: "error when the server returns 400 Bad Request with bad JSON body on the first call",
tokenType: provider.RefreshTokenType, tokenType: upstreamprovider.RefreshTokenType,
returnStatusCodes: []int{http.StatusBadRequest}, returnStatusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{`invalid JSON body`}, returnErrBodies: []string{`invalid JSON body`},
wantErr: testutil.WantExactErrorString(`error parsing response body "invalid JSON body" on response with status code 400: invalid character 'i' looking for beginning of value`), wantErr: testutil.WantExactErrorString(`error parsing response body "invalid JSON body" on response with status code 400: invalid character 'i' looking for beginning of value`),
@ -560,7 +561,7 @@ func TestProviderConfig(t *testing.T) {
}, },
{ {
name: "error when the server returns 400 Bad Request with empty body", name: "error when the server returns 400 Bad Request with empty body",
tokenType: provider.RefreshTokenType, tokenType: upstreamprovider.RefreshTokenType,
returnStatusCodes: []int{http.StatusBadRequest}, returnStatusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{``}, returnErrBodies: []string{``},
wantErr: testutil.WantExactErrorString(`error parsing response body "" on response with status code 400: unexpected end of JSON input`), wantErr: testutil.WantExactErrorString(`error parsing response body "" on response with status code 400: unexpected end of JSON input`),
@ -570,7 +571,7 @@ func TestProviderConfig(t *testing.T) {
}, },
{ {
name: "error when the server returns 400 Bad Request on the first call due to client auth, then any other error on second call", name: "error when the server returns 400 Bad Request on the first call due to client auth, then any other error on second call",
tokenType: provider.RefreshTokenType, tokenType: upstreamprovider.RefreshTokenType,
returnStatusCodes: []int{http.StatusBadRequest, http.StatusForbidden}, returnStatusCodes: []int{http.StatusBadRequest, http.StatusForbidden},
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""}, returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""},
wantErr: testutil.WantExactErrorString("server responded with status 403"), wantErr: testutil.WantExactErrorString("server responded with status 403"),
@ -580,7 +581,7 @@ func TestProviderConfig(t *testing.T) {
}, },
{ {
name: "error when server returns any other 400 error on first call", name: "error when server returns any other 400 error on first call",
tokenType: provider.RefreshTokenType, tokenType: upstreamprovider.RefreshTokenType,
returnStatusCodes: []int{http.StatusBadRequest}, returnStatusCodes: []int{http.StatusBadRequest},
returnErrBodies: []string{`{ "error":"anything_else", "error_description":"unhappy" }`}, returnErrBodies: []string{`{ "error":"anything_else", "error_description":"unhappy" }`},
wantErr: testutil.WantExactErrorString(`server responded with status 400 with body: { "error":"anything_else", "error_description":"unhappy" }`), wantErr: testutil.WantExactErrorString(`server responded with status 400 with body: { "error":"anything_else", "error_description":"unhappy" }`),
@ -590,7 +591,7 @@ func TestProviderConfig(t *testing.T) {
}, },
{ {
name: "error when server returns any other error aside from 400 on first call", name: "error when server returns any other error aside from 400 on first call",
tokenType: provider.RefreshTokenType, tokenType: upstreamprovider.RefreshTokenType,
returnStatusCodes: []int{http.StatusForbidden}, returnStatusCodes: []int{http.StatusForbidden},
returnErrBodies: []string{""}, returnErrBodies: []string{""},
wantErr: testutil.WantExactErrorString("server responded with status 403"), wantErr: testutil.WantExactErrorString("server responded with status 403"),
@ -600,7 +601,7 @@ func TestProviderConfig(t *testing.T) {
}, },
{ {
name: "retryable error when server returns 503 on first call", name: "retryable error when server returns 503 on first call",
tokenType: provider.RefreshTokenType, tokenType: upstreamprovider.RefreshTokenType,
returnStatusCodes: []int{http.StatusServiceUnavailable}, // 503 returnStatusCodes: []int{http.StatusServiceUnavailable}, // 503
returnErrBodies: []string{""}, returnErrBodies: []string{""},
wantErr: testutil.WantExactErrorString("retryable revocation error: server responded with status 503"), wantErr: testutil.WantExactErrorString("retryable revocation error: server responded with status 503"),
@ -610,7 +611,7 @@ func TestProviderConfig(t *testing.T) {
}, },
{ {
name: "retryable error when the server returns 400 Bad Request on the first call due to client auth, then 503 on second call", name: "retryable error when the server returns 400 Bad Request on the first call due to client auth, then 503 on second call",
tokenType: provider.AccessTokenType, tokenType: upstreamprovider.AccessTokenType,
returnStatusCodes: []int{http.StatusBadRequest, http.StatusServiceUnavailable}, // 400, 503 returnStatusCodes: []int{http.StatusBadRequest, http.StatusServiceUnavailable}, // 400, 503
returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""}, returnErrBodies: []string{`{ "error":"invalid_client", "error_description":"unhappy" }`, ""},
wantErr: testutil.WantExactErrorString("retryable revocation error: server responded with status 503"), wantErr: testutil.WantExactErrorString("retryable revocation error: server responded with status 503"),
@ -620,7 +621,7 @@ func TestProviderConfig(t *testing.T) {
}, },
{ {
name: "retryable error when server returns any 5xx status on first call, testing lower bound of 5xx range", name: "retryable error when server returns any 5xx status on first call, testing lower bound of 5xx range",
tokenType: provider.RefreshTokenType, tokenType: upstreamprovider.RefreshTokenType,
returnStatusCodes: []int{http.StatusInternalServerError}, // 500 returnStatusCodes: []int{http.StatusInternalServerError}, // 500
returnErrBodies: []string{""}, returnErrBodies: []string{""},
wantErr: testutil.WantExactErrorString("retryable revocation error: server responded with status 500"), wantErr: testutil.WantExactErrorString("retryable revocation error: server responded with status 500"),
@ -630,7 +631,7 @@ func TestProviderConfig(t *testing.T) {
}, },
{ {
name: "retryable error when server returns any 5xx status on first call, testing upper bound of 5xx range", name: "retryable error when server returns any 5xx status on first call, testing upper bound of 5xx range",
tokenType: provider.RefreshTokenType, tokenType: upstreamprovider.RefreshTokenType,
returnStatusCodes: []int{599}, // not defined by an RFC, but sometimes considered Network Connect Timeout Error returnStatusCodes: []int{599}, // not defined by an RFC, but sometimes considered Network Connect Timeout Error
returnErrBodies: []string{""}, returnErrBodies: []string{""},
wantErr: testutil.WantExactErrorString("retryable revocation error: server responded with status 599"), wantErr: testutil.WantExactErrorString("retryable revocation error: server responded with status 599"),
@ -640,7 +641,7 @@ func TestProviderConfig(t *testing.T) {
}, },
{ {
name: "retryable error when the server cannot be reached", name: "retryable error when the server cannot be reached",
tokenType: provider.AccessTokenType, tokenType: upstreamprovider.AccessTokenType,
unreachableServer: true, unreachableServer: true,
wantErr: testutil.WantMatchingErrorString("^retryable revocation error: Post .*: dial tcp .*: connect: connection refused$"), wantErr: testutil.WantMatchingErrorString("^retryable revocation error: Post .*: dial tcp .*: connect: connection refused$"),
wantRetryableErrType: true, wantRetryableErrType: true,

View File

@ -33,7 +33,7 @@ import (
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader" "go.pinniped.dev/internal/httputil/securityheader"
"go.pinniped.dev/internal/net/phttp" "go.pinniped.dev/internal/net/phttp"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/upstreamoidc" "go.pinniped.dev/internal/upstreamoidc"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
@ -107,7 +107,7 @@ type handlerState struct {
getEnv func(key string) string getEnv func(key string) string
listen func(string, string) (net.Listener, error) listen func(string, string) (net.Listener, error)
isTTY func(int) bool isTTY func(int) bool
getProvider func(*oauth2.Config, *coreosoidc.Provider, *http.Client) provider.UpstreamOIDCIdentityProviderI getProvider func(*oauth2.Config, *coreosoidc.Provider, *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI
validateIDToken func(ctx context.Context, provider *coreosoidc.Provider, audience string, token string) (*coreosoidc.IDToken, error) validateIDToken func(ctx context.Context, provider *coreosoidc.Provider, audience string, token string) (*coreosoidc.IDToken, error)
promptForValue func(ctx context.Context, promptLabel string) (string, error) promptForValue func(ctx context.Context, promptLabel string) (string, error)
promptForSecret func(promptLabel string) (string, error) promptForSecret func(promptLabel string) (string, error)
@ -196,10 +196,11 @@ func WithSkipListen() Option {
// SessionCacheKey contains the data used to select a valid session cache entry. // SessionCacheKey contains the data used to select a valid session cache entry.
type SessionCacheKey struct { type SessionCacheKey struct {
Issuer string `json:"issuer"` Issuer string `json:"issuer"`
ClientID string `json:"clientID"` ClientID string `json:"clientID"`
Scopes []string `json:"scopes"` Scopes []string `json:"scopes"`
RedirectURI string `json:"redirect_uri"` RedirectURI string `json:"redirect_uri"`
UpstreamProviderName string `json:"upstream_provider_name,omitempty"`
} }
type SessionCache interface { type SessionCache interface {
@ -351,6 +352,10 @@ func (h *handlerState) baseLogin() (*oidctypes.Token, error) {
ClientID: h.clientID, ClientID: h.clientID,
Scopes: h.scopes, Scopes: h.scopes,
RedirectURI: (&url.URL{Scheme: "http", Host: h.listenAddr, Path: h.callbackPath}).String(), RedirectURI: (&url.URL{Scheme: "http", Host: h.listenAddr, Path: h.callbackPath}).String(),
// When using a Supervisor with multiple IDPs, the cache keys need to be different for each IDP
// so a user can have multiple sessions going for each IDP at the same time.
// When using a non-Supervisor OIDC provider, then this value will be blank, so it won't be part of the key.
UpstreamProviderName: h.upstreamIdentityProviderName,
} }
// If the ID token is still valid for a bit, return it immediately and skip the rest of the flow. // If the ID token is still valid for a bit, return it immediately and skip the rest of the flow.

View File

@ -1,4 +1,4 @@
// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2023 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package oidcclient package oidcclient
@ -32,7 +32,7 @@ import (
"go.pinniped.dev/internal/httputil/roundtripper" "go.pinniped.dev/internal/httputil/roundtripper"
"go.pinniped.dev/internal/mocks/mockupstreamoidcidentityprovider" "go.pinniped.dev/internal/mocks/mockupstreamoidcidentityprovider"
"go.pinniped.dev/internal/net/phttp" "go.pinniped.dev/internal/net/phttp"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider/upstreamprovider"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/internal/testutil/testlogger" "go.pinniped.dev/internal/testutil/testlogger"
@ -504,7 +504,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return func(h *handlerState) error { return func(h *handlerState) error {
require.NoError(t, WithClient(newClientForServer(successServer))(h)) require.NoError(t, WithClient(newClientForServer(successServer))(h))
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true, false). ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true, false).
@ -553,7 +553,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return func(h *handlerState) error { return func(h *handlerState) error {
require.NoError(t, WithClient(newClientForServer(successServer))(h)) require.NoError(t, WithClient(newClientForServer(successServer))(h))
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true, false). ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true, false).
@ -1159,7 +1159,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
fmt.Sprintf("http://127.0.0.1:0/callback?code=%s&state=test-state", fakeAuthCode), fmt.Sprintf("http://127.0.0.1:0/callback?code=%s&state=test-state", fakeAuthCode),
}}, }},
}, nil) }, nil)
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ExchangeAuthcodeAndValidateTokens( ExchangeAuthcodeAndValidateTokens(
@ -1181,7 +1181,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return func(h *handlerState) error { return func(h *handlerState) error {
fakeAuthCode := "test-authcode-value" fakeAuthCode := "test-authcode-value"
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ExchangeAuthcodeAndValidateTokens( ExchangeAuthcodeAndValidateTokens(
@ -1281,7 +1281,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return func(h *handlerState) error { return func(h *handlerState) error {
fakeAuthCode := "test-authcode-value" fakeAuthCode := "test-authcode-value"
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ExchangeAuthcodeAndValidateTokens( ExchangeAuthcodeAndValidateTokens(
@ -1392,7 +1392,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
return func(h *handlerState) error { return func(h *handlerState) error {
fakeAuthCode := "test-authcode-value" fakeAuthCode := "test-authcode-value"
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ExchangeAuthcodeAndValidateTokens( ExchangeAuthcodeAndValidateTokens(
@ -1855,7 +1855,7 @@ func TestLogin(t *testing.T) { //nolint:gocyclo
}) })
h.cache = cache h.cache = cache
h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(config *oauth2.Config, provider *oidc.Provider, client *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true, false). ValidateTokenAndMergeWithUserInfo(gomock.Any(), HasAccessToken(testToken.AccessToken.Token), nonce.Nonce(""), true, false).
@ -1993,7 +1993,7 @@ func TestHandlePasteCallback(t *testing.T) {
return "invalid", nil return "invalid", nil
} }
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
@ -2017,7 +2017,7 @@ func TestHandlePasteCallback(t *testing.T) {
return "valid", nil return "valid", nil
} }
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
@ -2237,7 +2237,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "invalid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
@ -2256,7 +2256,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
return func(h *handlerState) error { return func(h *handlerState) error {
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
@ -2282,7 +2282,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
return func(h *handlerState) error { return func(h *handlerState) error {
h.useFormPost = true h.useFormPost = true
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).
@ -2311,7 +2311,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
return func(h *handlerState) error { return func(h *handlerState) error {
h.useFormPost = true h.useFormPost = true
h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI} h.oauth2Config = &oauth2.Config{RedirectURL: testRedirectURI}
h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) provider.UpstreamOIDCIdentityProviderI { h.getProvider = func(_ *oauth2.Config, _ *oidc.Provider, _ *http.Client) upstreamprovider.UpstreamOIDCIdentityProviderI {
mock := mockUpstream(t) mock := mockUpstream(t)
mock.EXPECT(). mock.EXPECT().
ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI). ExchangeAuthcodeAndValidateTokens(gomock.Any(), "valid", pkce.Code("test-pkce"), nonce.Nonce("test-nonce"), testRedirectURI).