callback_handler.go: start happy path test with redirect

Next steps: fosite storage?

Signed-off-by: Ryan Richard <richardry@vmware.com>
This commit is contained in:
Andrew Keesler 2020-11-16 17:07:34 -05:00 committed by Ryan Richard
parent 052cdc40dc
commit 1c7601a2b5
No known key found for this signature in database
GPG Key ID: 27CE0444346F9413
2 changed files with 84 additions and 15 deletions

View File

@ -5,7 +5,9 @@
package callback package callback
import ( import (
"fmt"
"net/http" "net/http"
"net/url"
"path" "path"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
@ -20,7 +22,8 @@ func NewHandler(
stateDecoder, cookieDecoder oidc.Decoder, stateDecoder, cookieDecoder oidc.Decoder,
) http.Handler { ) http.Handler {
return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
if err := validateRequest(r, stateDecoder, cookieDecoder); err != nil { state, err := validateRequest(r, stateDecoder, cookieDecoder)
if err != nil {
return err return err
} }
@ -29,43 +32,56 @@ func NewHandler(
return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found") return httperr.New(http.StatusUnprocessableEntity, "upstream provider not found")
} }
downstreamAuthParams, err := url.ParseQuery(state.AuthParams)
if err != nil {
panic(err)
}
downstreamCallbackURL := fmt.Sprintf(
"%s?code=%s&state=%s",
downstreamAuthParams.Get("redirect_uri"),
url.QueryEscape("some-code"),
url.QueryEscape(downstreamAuthParams.Get("state")),
)
http.Redirect(w, r, downstreamCallbackURL, 302)
return nil return nil
}) })
} }
func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) error { func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) {
if r.Method != http.MethodGet { if r.Method != http.MethodGet {
return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) return nil, httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method)
} }
csrfValue, err := readCSRFCookie(r, cookieDecoder) csrfValue, err := readCSRFCookie(r, cookieDecoder)
if err != nil { if err != nil {
plog.InfoErr("error reading CSRF cookie", err) plog.InfoErr("error reading CSRF cookie", err)
return err return nil, err
} }
if r.FormValue("code") == "" { if r.FormValue("code") == "" {
plog.Info("code param not found") plog.Info("code param not found")
return httperr.New(http.StatusBadRequest, "code param not found") return nil, httperr.New(http.StatusBadRequest, "code param not found")
} }
if r.FormValue("state") == "" { if r.FormValue("state") == "" {
plog.Info("state param not found") plog.Info("state param not found")
return httperr.New(http.StatusBadRequest, "state param not found") return nil, httperr.New(http.StatusBadRequest, "state param not found")
} }
state, err := readState(r, stateDecoder) state, err := readState(r, stateDecoder)
if err != nil { if err != nil {
plog.InfoErr("error reading state", err) plog.InfoErr("error reading state", err)
return err return nil, err
} }
if state.CSRFToken != csrfValue { if state.CSRFToken != csrfValue {
plog.InfoErr("CSRF value does not match", err) plog.InfoErr("CSRF value does not match", err)
return httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err) return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err)
} }
return nil return state, nil
} }
func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProvider { func findUpstreamIDPConfig(r *http.Request, idpListGetter oidc.IDPListGetter) *provider.UpstreamOIDCIdentityProvider {

View File

@ -5,9 +5,11 @@ package callback
import ( import (
"fmt" "fmt"
"html"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"regexp"
"testing" "testing"
"github.com/gorilla/securecookie" "github.com/gorilla/securecookie"
@ -22,6 +24,10 @@ const (
) )
func TestCallbackEndpoint(t *testing.T) { func TestCallbackEndpoint(t *testing.T) {
const (
downstreamRedirectURI = "http://127.0.0.1/callback"
)
upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth") upstreamAuthURL, err := url.Parse("https://some-upstream-idp:8443/auth")
require.NoError(t, err) require.NoError(t, err)
otherUpstreamAuthURL, err := url.Parse("https://some-other-upstream-idp:8443/auth") otherUpstreamAuthURL, err := url.Parse("https://some-other-upstream-idp:8443/auth")
@ -53,13 +59,25 @@ func TestCallbackEndpoint(t *testing.T) {
var happyCookieCodec = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey) var happyCookieCodec = securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey)
happyCookieCodec.SetSerializer(securecookie.JSONEncoder{}) happyCookieCodec.SetSerializer(securecookie.JSONEncoder{})
happyDownstreamState := "some-downstream-state"
happyOrignalRequestParams := url.Values{
"response_type": []string{"code"},
"scope": []string{"openid profile email"},
"client_id": []string{"pinniped-cli"},
"state": []string{happyDownstreamState},
"nonce": []string{"some-nonce-value"},
"code_challenge": []string{"some-challenge"},
"code_challenge_method": []string{"S256"},
"redirect_uri": []string{downstreamRedirectURI},
}.Encode()
happyCSRF := "test-csrf" happyCSRF := "test-csrf"
happyPKCE := "test-pkce" happyPKCE := "test-pkce"
happyNonce := "test-nonce" happyNonce := "test-nonce"
happyState, err := happyStateCodec.Encode("s", happyState, err := happyStateCodec.Encode("s",
testutil.ExpectedUpstreamStateParamFormat{ testutil.ExpectedUpstreamStateParamFormat{
P: "todo query goes here", P: happyOrignalRequestParams,
N: happyNonce, N: happyNonce,
C: happyCSRF, C: happyCSRF,
K: happyPKCE, K: happyPKCE,
@ -70,7 +88,7 @@ func TestCallbackEndpoint(t *testing.T) {
wrongCSRFValueState, err := happyStateCodec.Encode("s", wrongCSRFValueState, err := happyStateCodec.Encode("s",
testutil.ExpectedUpstreamStateParamFormat{ testutil.ExpectedUpstreamStateParamFormat{
P: "todo query goes here", P: happyOrignalRequestParams,
N: happyNonce, N: happyNonce,
C: "wrong-csrf-value", C: "wrong-csrf-value",
K: happyPKCE, K: happyPKCE,
@ -81,7 +99,7 @@ func TestCallbackEndpoint(t *testing.T) {
wrongVersionState, err := happyStateCodec.Encode("s", wrongVersionState, err := happyStateCodec.Encode("s",
testutil.ExpectedUpstreamStateParamFormat{ testutil.ExpectedUpstreamStateParamFormat{
P: "todo query goes here", P: happyOrignalRequestParams,
N: happyNonce, N: happyNonce,
C: happyCSRF, C: happyCSRF,
K: happyPKCE, K: happyPKCE,
@ -102,11 +120,22 @@ func TestCallbackEndpoint(t *testing.T) {
path string path string
csrfCookie string csrfCookie string
wantStatus int wantStatus int
wantBody string wantBody string
wantRedirectLocationRegexp string
}{ }{
// Happy path // Happy path
// TODO: GET with good state and cookie and successful upstream token exchange and 302 to downstream client callback with its state and code // TODO: GET with good state and cookie and successful upstream token exchange and 302 to downstream client callback with its state and code
{
name: "GET with good state and cookie and successful upstream token exchange returns 302 to downstream client callback with its state and code",
idpListGetter: testutil.NewIDPListGetter(upstreamOIDCIdentityProvider),
method: http.MethodGet,
path: newRequestPath().WithState(happyState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusFound,
wantRedirectLocationRegexp: downstreamRedirectURI + `\?code=([^&]+)&state=` + happyDownstreamState,
},
// TODO: when we call the callback twice in a row, we get two different auth codes (to prove we are using an RNG for auth codes)
// Pre-upstream-exchange verification // Pre-upstream-exchange verification
{ {
@ -240,7 +269,31 @@ func TestCallbackEndpoint(t *testing.T) {
subject.ServeHTTP(rsp, req) subject.ServeHTTP(rsp, req)
require.Equal(t, test.wantStatus, rsp.Code) require.Equal(t, test.wantStatus, rsp.Code)
require.Equal(t, test.wantBody, rsp.Body.String())
require.False(t, test.wantBody != "" && test.wantRedirectLocationRegexp != "", "test cannot set both body and redirect assertions")
switch {
case test.wantBody != "":
require.Empty(t, rsp.Header().Values("Location"))
require.Equal(t, test.wantBody, rsp.Body.String())
case test.wantRedirectLocationRegexp != "":
// Assert that Location header matches regular expression.
require.Len(t, rsp.Header().Values("Location"), 1)
actualLocation := rsp.Header().Get("Location")
regex := regexp.MustCompile(test.wantRedirectLocationRegexp)
submatches := regex.FindStringSubmatch(actualLocation)
require.Lenf(t, submatches, 2, "no regexp match in actualLocation: %q", actualLocation)
capturedAuthCode := submatches[1]
_ = capturedAuthCode
// Assert capturedAuthCode storage stuff...
// Assert that body contains anchor tag with redirect location.
anchorTagWithLocationHref := fmt.Sprintf("<a href=\"%s\">Found</a>.\n\n", html.EscapeString(actualLocation))
require.Equal(t, anchorTagWithLocationHref, rsp.Body.String())
default:
require.Empty(t, rsp.Header().Values("Location"))
require.Empty(t, rsp.Body.String())
}
}) })
} }
} }