Integration test + making sure we get the session correctly in token handler

This commit is contained in:
Margo Crawford 2022-01-19 13:20:49 -08:00
parent b0ea7063c7
commit 38d184fe81
2 changed files with 25 additions and 4 deletions

View File

@ -67,10 +67,14 @@ func NewHandler(
// When we are in the authorization code flow, check if we have any warnings that previous handlers want us // When we are in the authorization code flow, check if we have any warnings that previous handlers want us
// to send to the client to be printed on the CLI. // to send to the client to be printed on the CLI.
if accessRequest.GetGrantTypes().ExactOne("authorization_code") { if accessRequest.GetGrantTypes().ExactOne("authorization_code") {
for _, warningText := range session.Custom.Warnings { storedSession := accessRequest.GetSession().(*psession.PinnipedSession)
customSessionData := storedSession.Custom
if customSessionData != nil {
for _, warningText := range customSessionData.Warnings {
warning.AddWarning(r.Context(), "", warningText) warning.AddWarning(r.Context(), "", warningText)
} }
} }
}
accessResponse, err := oauthHelper.NewAccessResponse(r.Context(), accessRequest) accessResponse, err := oauthHelper.NewAccessResponse(r.Context(), accessRequest)
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved. // Copyright 2020-2022 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package integration package integration
@ -290,6 +290,20 @@ func TestE2EFullIntegration(t *testing.T) { // nolint:gocyclo
Resource: "namespaces", Resource: "namespaces",
}) })
var additionalScopes []string
// If we're using dex, we will test that we see a warning when the access token
// lifetime is too short (we have it set to 20 minutes) and it's using access token based refresh.
// To ensure that access token refresh happens rather than refresh token, don't ask for the offline_access scope.
// In other environments, test the refresh token based flow.
if len(env.ToolsNamespace) == 0 {
additionalScopes = env.SupervisorUpstreamOIDC.AdditionalScopes
} else {
for _, additionalScope := range env.SupervisorUpstreamOIDC.AdditionalScopes {
if additionalScope != "offline_access" {
additionalScopes = append(additionalScopes, additionalScope)
}
}
}
// Create upstream OIDC provider and wait for it to become ready. // Create upstream OIDC provider and wait for it to become ready.
testlib.CreateTestOIDCIdentityProvider(t, idpv1alpha1.OIDCIdentityProviderSpec{ testlib.CreateTestOIDCIdentityProvider(t, idpv1alpha1.OIDCIdentityProviderSpec{
Issuer: env.SupervisorUpstreamOIDC.Issuer, Issuer: env.SupervisorUpstreamOIDC.Issuer,
@ -297,7 +311,7 @@ func TestE2EFullIntegration(t *testing.T) { // nolint:gocyclo
CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorUpstreamOIDC.CABundle)), CertificateAuthorityData: base64.StdEncoding.EncodeToString([]byte(env.SupervisorUpstreamOIDC.CABundle)),
}, },
AuthorizationConfig: idpv1alpha1.OIDCAuthorizationConfig{ AuthorizationConfig: idpv1alpha1.OIDCAuthorizationConfig{
AdditionalScopes: env.SupervisorUpstreamOIDC.AdditionalScopes, AdditionalScopes: additionalScopes,
}, },
Claims: idpv1alpha1.OIDCClaims{ Claims: idpv1alpha1.OIDCClaims{
Username: env.SupervisorUpstreamOIDC.UsernameClaim, Username: env.SupervisorUpstreamOIDC.UsernameClaim,
@ -369,6 +383,9 @@ func TestE2EFullIntegration(t *testing.T) { // nolint:gocyclo
// Ignore any errors returned because there is always an error on linux. // Ignore any errors returned because there is always an error on linux.
kubectlOutputBytes, _ := ioutil.ReadAll(ptyFile) kubectlOutputBytes, _ := ioutil.ReadAll(ptyFile)
requireKubectlGetNamespaceOutput(t, env, string(kubectlOutputBytes)) requireKubectlGetNamespaceOutput(t, env, string(kubectlOutputBytes))
if len(env.ToolsNamespace) > 0 {
require.Contains(t, string(kubectlOutputBytes), "Access token from identity provider has lifetime of less than 3 hours. Expect frequent prompts to log in.")
}
t.Logf("first kubectl command took %s", time.Since(start).String()) t.Logf("first kubectl command took %s", time.Since(start).String())