diff --git a/internal/oidc/auth/auth_handler_test.go b/internal/oidc/auth/auth_handler_test.go
index b9616207..fe38f2ab 100644
--- a/internal/oidc/auth/auth_handler_test.go
+++ b/internal/oidc/auth/auth_handler_test.go
@@ -248,27 +248,25 @@ func TestAuthorizationEndpoint(t *testing.T) {
wantCSRFCookieHeader string
wantUpstreamStateParamInLocationHeader bool
+ wantBodyStringWithLocationInHref bool
}
tests := []testCase{
{
- name: "happy path using GET",
- issuer: issuer,
- idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
- generateCSRF: happyCSRFGenerator,
- generatePKCE: happyPKCEGenerator,
- generateNonce: happyNonceGenerator,
- encoder: happyEncoder,
- method: http.MethodGet,
- path: happyGetRequestPath,
- wantStatus: http.StatusFound,
- wantContentType: "text/html; charset=utf-8",
- wantBodyString: fmt.Sprintf(`Found.%s`,
- html.EscapeString(expectedRedirectLocation(expectedUpstreamStateParam(nil))),
- "\n\n",
- ),
+ name: "happy path using GET",
+ issuer: issuer,
+ idpListGetter: newIDPListGetter(upstreamOIDCIdentityProvider),
+ generateCSRF: happyCSRFGenerator,
+ generatePKCE: happyPKCEGenerator,
+ generateNonce: happyNonceGenerator,
+ encoder: happyEncoder,
+ method: http.MethodGet,
+ path: happyGetRequestPath,
+ wantStatus: http.StatusFound,
+ wantContentType: "text/html; charset=utf-8",
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(nil)),
wantUpstreamStateParamInLocationHeader: true,
+ wantBodyStringWithLocationInHref: true,
},
{
name: "happy path using POST",
@@ -301,19 +299,14 @@ func TestAuthorizationEndpoint(t *testing.T) {
path: modifiedHappyGetRequestPath(map[string]string{
"redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client
}),
- wantStatus: http.StatusFound,
- wantContentType: "text/html; charset=utf-8",
- wantBodyString: fmt.Sprintf(`Found.%s`,
- html.EscapeString(expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{
- "redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client
- }))),
- "\n\n",
- ),
+ wantStatus: http.StatusFound,
+ wantContentType: "text/html; charset=utf-8",
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{
"redirect_uri": downstreamRedirectURIWithDifferentPort, // not the same port number that is registered for the client
})),
wantUpstreamStateParamInLocationHeader: true,
+ wantBodyStringWithLocationInHref: true,
},
{
name: "downstream redirect uri does not match what is configured for client",
@@ -491,16 +484,13 @@ func TestAuthorizationEndpoint(t *testing.T) {
encoder: happyEncoder,
method: http.MethodGet,
// The following prompt value is illegal when openid is requested, but note that openid is not requested.
- path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login", "scope": "email"}),
- wantStatus: http.StatusFound,
- wantContentType: "text/html; charset=utf-8",
- wantBodyString: fmt.Sprintf(`Found.%s`,
- html.EscapeString(expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{"prompt": "none login", "scope": "email"}))),
- "\n\n",
- ),
+ path: modifiedHappyGetRequestPath(map[string]string{"prompt": "none login", "scope": "email"}),
+ wantStatus: http.StatusFound,
+ wantContentType: "text/html; charset=utf-8",
wantCSRFCookieHeader: happyCSRFSetCookieHeaderValue,
wantLocationHeader: expectedRedirectLocation(expectedUpstreamStateParam(map[string]string{"prompt": "none login", "scope": "email"})),
wantUpstreamStateParamInLocationHeader: true,
+ wantBodyStringWithLocationInHref: true,
},
{
name: "state does not have enough entropy",
@@ -634,18 +624,24 @@ func TestAuthorizationEndpoint(t *testing.T) {
require.Equal(t, test.wantStatus, rsp.Code)
requireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType)
+ actualLocation := rsp.Header().Get("Location")
if test.wantLocationHeader != "" {
- actualLocation := rsp.Header().Get("Location")
if test.wantUpstreamStateParamInLocationHeader {
requireEqualDecodedStateParams(t, actualLocation, test.wantLocationHeader, test.encoder)
}
- requireEqualURLs(t, actualLocation, test.wantLocationHeader)
+ // The upstream state param is encoded using a timestamp at the beginning so we don't want to
+ // compare those states since they may be different, but we do want to compare the downstream
+ // state param that should be exactly the same.
+ requireEqualURLs(t, actualLocation, test.wantLocationHeader, test.wantUpstreamStateParamInLocationHeader)
} else {
require.Empty(t, rsp.Header().Values("Location"))
}
if test.wantBodyJSON != "" {
require.JSONEq(t, test.wantBodyJSON, rsp.Body.String())
+ } else if test.wantBodyStringWithLocationInHref {
+ anchorTagWithLocationHref := fmt.Sprintf("Found.\n\n", html.EscapeString(actualLocation))
+ require.Equal(t, anchorTagWithLocationHref, rsp.Body.String())
} else {
require.Equal(t, test.wantBodyString, rsp.Body.String())
}
@@ -769,7 +765,7 @@ func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL
require.Equal(t, expectedDecodedStateParam, actualDecodedStateParam)
}
-func requireEqualURLs(t *testing.T, actualURL string, expectedURL string) {
+func requireEqualURLs(t *testing.T, actualURL string, expectedURL string, ignoreState bool) {
t.Helper()
actualLocationURL, err := url.Parse(actualURL)
require.NoError(t, err)
@@ -779,7 +775,16 @@ func requireEqualURLs(t *testing.T, actualURL string, expectedURL string) {
require.Equal(t, expectedLocationURL.User, actualLocationURL.User)
require.Equal(t, expectedLocationURL.Host, actualLocationURL.Host)
require.Equal(t, expectedLocationURL.Path, actualLocationURL.Path)
- require.Equal(t, expectedLocationURL.Query(), actualLocationURL.Query())
+
+ expectedLocationQuery := expectedLocationURL.Query()
+ actualLocationQuery := actualLocationURL.Query()
+ // Let the caller ignore the state, since it may contain a digest at the end that is difficult to
+ // predict because it depends on a time.Now() timestamp.
+ if ignoreState {
+ expectedLocationQuery.Del("state")
+ actualLocationQuery.Del("state")
+ }
+ require.Equal(t, expectedLocationQuery, actualLocationQuery)
}
func newIDPListGetter(upstreamOIDCIdentityProviders ...provider.UpstreamOIDCIdentityProvider) provider.DynamicUpstreamIDPProvider {