Implement login_handler.go to defer to other handlers

The other handlers for GET and POST requests are not yet implemented in
this commit. The shared handler code in login_handler.go takes care of
things checking the method, checking the CSRF cookie, decoding the state
param, and adding security headers on behalf of both the GET and POST
handlers.

Some code has been extracted from callback_handler.go to be shared.
This commit is contained in:
Ryan Richard 2022-04-26 15:30:39 -07:00
parent eb1d3812ec
commit 65eed7e742
9 changed files with 656 additions and 99 deletions

View File

@ -5,7 +5,6 @@
package callback package callback
import ( import (
"crypto/subtle"
"net/http" "net/http"
"net/url" "net/url"
@ -14,7 +13,6 @@ import (
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader" "go.pinniped.dev/internal/httputil/securityheader"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/downstreamsession" "go.pinniped.dev/internal/oidc/downstreamsession"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/oidc/provider/formposthtml" "go.pinniped.dev/internal/oidc/provider/formposthtml"
@ -102,9 +100,9 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder)
return nil, httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method) return nil, httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET)", r.Method)
} }
csrfValue, err := readCSRFCookie(r, cookieDecoder) _, decodedState, err := oidc.ReadStateParamAndValidateCSRFCookie(r, cookieDecoder, stateDecoder)
if err != nil { if err != nil {
plog.InfoErr("error reading CSRF cookie", err) plog.InfoErr("state or CSRF error", err)
return nil, err return nil, err
} }
@ -113,23 +111,7 @@ func validateRequest(r *http.Request, stateDecoder, cookieDecoder oidc.Decoder)
return nil, httperr.New(http.StatusBadRequest, "code param not found") return nil, httperr.New(http.StatusBadRequest, "code param not found")
} }
if r.FormValue("state") == "" { return decodedState, nil
plog.Info("state param not found")
return nil, httperr.New(http.StatusBadRequest, "state param not found")
}
state, err := readState(r, stateDecoder)
if err != nil {
plog.InfoErr("error reading state", err)
return nil, err
}
if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfValue)) != 1 {
plog.InfoErr("CSRF value does not match", err)
return nil, httperr.Wrap(http.StatusForbidden, "CSRF value does not match", err)
}
return state, nil
} }
func findUpstreamIDPConfig(upstreamName string, upstreamIDPs oidc.UpstreamOIDCIdentityProvidersLister) provider.UpstreamOIDCIdentityProviderI { func findUpstreamIDPConfig(upstreamName string, upstreamIDPs oidc.UpstreamOIDCIdentityProvidersLister) provider.UpstreamOIDCIdentityProviderI {
@ -140,36 +122,3 @@ func findUpstreamIDPConfig(upstreamName string, upstreamIDPs oidc.UpstreamOIDCId
} }
return nil return nil
} }
func readCSRFCookie(r *http.Request, cookieDecoder oidc.Decoder) (csrftoken.CSRFToken, error) {
receivedCSRFCookie, err := r.Cookie(oidc.CSRFCookieName)
if err != nil {
// Error means that the cookie was not found
return "", httperr.Wrap(http.StatusForbidden, "CSRF cookie is missing", err)
}
var csrfFromCookie csrftoken.CSRFToken
err = cookieDecoder.Decode(oidc.CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie)
if err != nil {
return "", httperr.Wrap(http.StatusForbidden, "error reading CSRF cookie", err)
}
return csrfFromCookie, nil
}
func readState(r *http.Request, stateDecoder oidc.Decoder) (*oidc.UpstreamStateParamData, error) {
var state oidc.UpstreamStateParamData
if err := stateDecoder.Decode(
oidc.UpstreamStateParamEncodingName,
r.FormValue("state"),
&state,
); err != nil {
return nil, httperr.New(http.StatusBadRequest, "error reading state")
}
if state.FormatVersion != oidc.UpstreamStateParamFormatVersion {
return nil, httperr.New(http.StatusUnprocessableEntity, "state format version is invalid")
}
return &state, nil
}

View File

@ -1156,10 +1156,8 @@ func (r *requestPath) String() string {
return path + params.Encode() return path + params.Encode()
} }
type upstreamStateParamBuilder oidctestutil.ExpectedUpstreamStateParamFormat func happyUpstreamStateParam() *oidctestutil.UpstreamStateParamBuilder {
return &oidctestutil.UpstreamStateParamBuilder{
func happyUpstreamStateParam() *upstreamStateParamBuilder {
return &upstreamStateParamBuilder{
U: happyUpstreamIDPName, U: happyUpstreamIDPName,
P: happyDownstreamRequestParams, P: happyDownstreamRequestParams,
T: "oidc", T: "oidc",
@ -1170,37 +1168,6 @@ func happyUpstreamStateParam() *upstreamStateParamBuilder {
} }
} }
func (b upstreamStateParamBuilder) Build(t *testing.T, stateEncoder *securecookie.SecureCookie) string {
state, err := stateEncoder.Encode("s", b)
require.NoError(t, err)
return state
}
func (b *upstreamStateParamBuilder) WithAuthorizeRequestParams(params string) *upstreamStateParamBuilder {
b.P = params
return b
}
func (b *upstreamStateParamBuilder) WithNonce(nonce string) *upstreamStateParamBuilder {
b.N = nonce
return b
}
func (b *upstreamStateParamBuilder) WithCSRF(csrf string) *upstreamStateParamBuilder {
b.C = csrf
return b
}
func (b *upstreamStateParamBuilder) WithPKCVE(pkce string) *upstreamStateParamBuilder {
b.K = pkce
return b
}
func (b *upstreamStateParamBuilder) WithStateVersion(version string) *upstreamStateParamBuilder {
b.V = version
return b
}
func happyUpstream() *oidctestutil.TestUpstreamOIDCIdentityProviderBuilder { func happyUpstream() *oidctestutil.TestUpstreamOIDCIdentityProviderBuilder {
return oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder(). return oidctestutil.NewTestUpstreamOIDCIdentityProviderBuilder().
WithName(happyUpstreamIDPName). WithName(happyUpstreamIDPName).

View File

@ -0,0 +1,17 @@
// Copyright 2022 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package login
import (
"net/http"
"go.pinniped.dev/internal/oidc"
)
func NewGetHandler(upstreamIDPs oidc.UpstreamIdentityProvidersLister) HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, encodedState string, decodedState *oidc.UpstreamStateParamData) error {
// TODO
return nil
}
}

View File

@ -5,19 +5,64 @@ package login
import ( import (
"net/http" "net/http"
idpdiscoveryv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idpdiscovery/v1alpha1"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/plog"
) )
// NewHandler returns an http.Handler that serves the login endpoint for IDPs that // HandlerFunc is a function that can handle either a GET or POST request for the login endpoint.
// don't have their own Web UI. type HandlerFunc func(
func NewHandler() http.Handler { w http.ResponseWriter,
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { r *http.Request,
if r.Method != http.MethodGet { encodedState string,
http.Error(w, `Method not allowed (try GET)`, http.StatusMethodNotAllowed) decodedState *oidc.UpstreamStateParamData,
return ) error
// NewHandler returns a http.Handler that serves the login endpoint for IDPs that don't have their own web UI for login.
//
// This handler takes care of the shared concerns between the GET and POST methods of the login endpoint:
// checking the method, checking the CSRF cookie, decoding the state param, and adding security headers.
// Then it defers the rest of the handling to the passed in handler functions for GET and POST requests.
// Note that CSRF protection isn't needed on GET requests, but it doesn't hurt. Putting it here
// keeps the implementations and tests of HandlerFunc simpler since they won't need to deal with any decoders.
// Users should always initially get redirected to this page from the authorization endpoint, and never need
// to navigate directly to this page in their browser without going through the authorization endpoint first.
// Once their browser has landed on this page, it should be okay for the user to refresh the browser.
func NewHandler(
stateDecoder oidc.Decoder,
cookieDecoder oidc.Decoder,
getHandler HandlerFunc, // use NewGetHandler() for production
postHandler HandlerFunc, // use NewPostHandler() for production
) http.Handler {
loginHandler := httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
var handler HandlerFunc
switch r.Method {
case http.MethodGet:
handler = getHandler
case http.MethodPost:
handler = postHandler
default:
return httperr.Newf(http.StatusMethodNotAllowed, "%s (try GET or POST)", r.Method)
} }
_, err := w.Write([]byte("<p>hello world</p>"))
encodedState, decodedState, err := oidc.ReadStateParamAndValidateCSRFCookie(r, cookieDecoder, stateDecoder)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) plog.InfoErr("state or CSRF error", err)
return err
} }
switch decodedState.UpstreamType {
case string(idpdiscoveryv1alpha1.IDPTypeLDAP), string(idpdiscoveryv1alpha1.IDPTypeActiveDirectory):
// these are the types supported by this endpoint, so no error here
default:
return httperr.Newf(http.StatusBadRequest, "not a supported upstream IDP type for this endpoint: %q", decodedState.UpstreamType)
}
return handler(w, r, encodedState, decodedState)
}) })
return securityheader.Wrap(loginHandler)
} }

View File

@ -0,0 +1,448 @@
// Copyright 2022 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package login
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/gorilla/securecookie"
"github.com/stretchr/testify/require"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/testutil"
"go.pinniped.dev/internal/testutil/oidctestutil"
)
func TestLoginEndpoint(t *testing.T) {
const (
htmlContentType = "text/html; charset=utf-8"
happyGetResult = "<p>get handler result</p>"
happyPostResult = "<p>post handler result</p>"
happyUpstreamIDPName = "upstream-idp-name"
happyUpstreamIDPType = "ldap"
happyDownstreamCSRF = "test-csrf"
happyDownstreamPKCE = "test-pkce"
happyDownstreamNonce = "test-nonce"
happyDownstreamStateVersion = "2"
downstreamClientID = "pinniped-cli"
happyDownstreamState = "8b-state"
downstreamNonce = "some-nonce-value"
downstreamPKCEChallenge = "some-challenge"
downstreamPKCEChallengeMethod = "S256"
downstreamRedirectURI = "http://127.0.0.1/callback"
)
happyDownstreamScopesRequested := []string{"openid"}
happyDownstreamRequestParamsQuery := url.Values{
"response_type": []string{"code"},
"scope": []string{strings.Join(happyDownstreamScopesRequested, " ")},
"client_id": []string{downstreamClientID},
"state": []string{happyDownstreamState},
"nonce": []string{downstreamNonce},
"code_challenge": []string{downstreamPKCEChallenge},
"code_challenge_method": []string{downstreamPKCEChallengeMethod},
"redirect_uri": []string{downstreamRedirectURI},
}
happyDownstreamRequestParams := happyDownstreamRequestParamsQuery.Encode()
expectedHappyDecodedUpstreamStateParam := func() *oidc.UpstreamStateParamData {
return &oidc.UpstreamStateParamData{
UpstreamName: happyUpstreamIDPName,
UpstreamType: happyUpstreamIDPType,
AuthParams: happyDownstreamRequestParams,
Nonce: happyDownstreamNonce,
CSRFToken: happyDownstreamCSRF,
PKCECode: happyDownstreamPKCE,
FormatVersion: happyDownstreamStateVersion,
}
}
expectedHappyDecodedUpstreamStateParamForActiveDirectory := func() *oidc.UpstreamStateParamData {
s := expectedHappyDecodedUpstreamStateParam()
s.UpstreamType = "activedirectory"
return s
}
happyUpstreamStateParam := func() *oidctestutil.UpstreamStateParamBuilder {
return &oidctestutil.UpstreamStateParamBuilder{
U: happyUpstreamIDPName,
T: happyUpstreamIDPType,
P: happyDownstreamRequestParams,
N: happyDownstreamNonce,
C: happyDownstreamCSRF,
K: happyDownstreamPKCE,
V: happyDownstreamStateVersion,
}
}
stateEncoderHashKey := []byte("fake-hash-secret")
stateEncoderBlockKey := []byte("0123456789ABCDEF") // block encryption requires 16/24/32 bytes for AES
cookieEncoderHashKey := []byte("fake-hash-secret2")
cookieEncoderBlockKey := []byte("0123456789ABCDE2") // block encryption requires 16/24/32 bytes for AES
require.NotEqual(t, stateEncoderHashKey, cookieEncoderHashKey)
require.NotEqual(t, stateEncoderBlockKey, cookieEncoderBlockKey)
happyStateCodec := securecookie.New(stateEncoderHashKey, stateEncoderBlockKey)
happyStateCodec.SetSerializer(securecookie.JSONEncoder{})
happyCookieCodec := securecookie.New(cookieEncoderHashKey, cookieEncoderBlockKey)
happyCookieCodec.SetSerializer(securecookie.JSONEncoder{})
happyState := happyUpstreamStateParam().Build(t, happyStateCodec)
happyPathWithState := newRequestPath().WithState(happyState).String()
happyActiveDirectoryState := happyUpstreamStateParam().WithUpstreamIDPType("activedirectory").Build(t, happyStateCodec)
encodedIncomingCookieCSRFValue, err := happyCookieCodec.Encode("csrf", happyDownstreamCSRF)
require.NoError(t, err)
happyCSRFCookie := "__Host-pinniped-csrf=" + encodedIncomingCookieCSRFValue
tests := []struct {
name string
method string
path string
csrfCookie string
getHandlerErr error
postHandlerErr error
wantStatus int
wantContentType string
wantBody string
wantEncodedState string
wantDecodedState *oidc.UpstreamStateParamData
}{
{
name: "PUT method is invalid",
method: http.MethodPut,
path: happyPathWithState,
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusMethodNotAllowed,
wantContentType: htmlContentType,
wantBody: "Method Not Allowed: PUT (try GET or POST)\n",
},
{
name: "PATCH method is invalid",
method: http.MethodPatch,
path: happyPathWithState,
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusMethodNotAllowed,
wantContentType: htmlContentType,
wantBody: "Method Not Allowed: PATCH (try GET or POST)\n",
},
{
name: "DELETE method is invalid",
method: http.MethodDelete,
path: happyPathWithState,
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusMethodNotAllowed,
wantContentType: htmlContentType,
wantBody: "Method Not Allowed: DELETE (try GET or POST)\n",
},
{
name: "HEAD method is invalid",
method: http.MethodHead,
path: happyPathWithState,
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusMethodNotAllowed,
wantContentType: htmlContentType,
wantBody: "Method Not Allowed: HEAD (try GET or POST)\n",
},
{
name: "CONNECT method is invalid",
method: http.MethodConnect,
path: happyPathWithState,
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusMethodNotAllowed,
wantContentType: htmlContentType,
wantBody: "Method Not Allowed: CONNECT (try GET or POST)\n",
},
{
name: "OPTIONS method is invalid",
method: http.MethodOptions,
path: happyPathWithState,
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusMethodNotAllowed,
wantContentType: htmlContentType,
wantBody: "Method Not Allowed: OPTIONS (try GET or POST)\n",
},
{
name: "TRACE method is invalid",
method: http.MethodTrace,
path: happyPathWithState,
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusMethodNotAllowed,
wantContentType: htmlContentType,
wantBody: "Method Not Allowed: TRACE (try GET or POST)\n",
},
{
name: "state param was not included on GET request",
method: http.MethodGet,
path: newRequestPath().WithoutState().String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest,
wantContentType: htmlContentType,
wantBody: "Bad Request: state param not found\n",
},
{
name: "state param was not included on POST request",
method: http.MethodPost,
path: newRequestPath().WithoutState().String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest,
wantContentType: htmlContentType,
wantBody: "Bad Request: state param not found\n",
},
{
name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason on GET request",
method: http.MethodGet,
path: newRequestPath().WithState("this-will-not-decode").String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest,
wantContentType: htmlContentType,
wantBody: "Bad Request: error reading state\n",
},
{
name: "state param was not signed correctly, has expired, or otherwise cannot be decoded for any reason on POST request",
method: http.MethodPost,
path: newRequestPath().WithState("this-will-not-decode").String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest,
wantContentType: htmlContentType,
wantBody: "Bad Request: error reading state\n",
},
{
name: "the CSRF cookie does not exist on GET request",
method: http.MethodGet,
path: happyPathWithState,
csrfCookie: "",
wantStatus: http.StatusForbidden,
wantContentType: htmlContentType,
wantBody: "Forbidden: CSRF cookie is missing\n",
},
{
name: "the CSRF cookie does not exist on POST request",
method: http.MethodPost,
path: happyPathWithState,
csrfCookie: "",
wantStatus: http.StatusForbidden,
wantContentType: htmlContentType,
wantBody: "Forbidden: CSRF cookie is missing\n",
},
{
name: "the CSRF cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason on GET request",
method: http.MethodGet,
path: happyPathWithState,
csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped",
wantStatus: http.StatusForbidden,
wantContentType: htmlContentType,
wantBody: "Forbidden: error reading CSRF cookie\n",
},
{
name: "the CSRF cookie was not signed correctly, has expired, or otherwise cannot be decoded for any reason on POST request",
method: http.MethodPost,
path: happyPathWithState,
csrfCookie: "__Host-pinniped-csrf=this-value-was-not-signed-by-pinniped",
wantStatus: http.StatusForbidden,
wantContentType: htmlContentType,
wantBody: "Forbidden: error reading CSRF cookie\n",
},
{
name: "cookie csrf value does not match state csrf value on GET request",
method: http.MethodGet,
path: newRequestPath().WithState(happyUpstreamStateParam().WithCSRF("wrong-csrf-value").Build(t, happyStateCodec)).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusForbidden,
wantContentType: htmlContentType,
wantBody: "Forbidden: CSRF value does not match\n",
},
{
name: "cookie csrf value does not match state csrf value on POST request",
method: http.MethodPost,
path: newRequestPath().WithState(happyUpstreamStateParam().WithCSRF("wrong-csrf-value").Build(t, happyStateCodec)).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusForbidden,
wantContentType: htmlContentType,
wantBody: "Forbidden: CSRF value does not match\n",
},
{
name: "GET request when upstream IDP type in state param is not supported by this endpoint",
method: http.MethodGet,
path: newRequestPath().WithState(
happyUpstreamStateParam().WithUpstreamIDPType("oidc").Build(t, happyStateCodec),
).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest,
wantContentType: htmlContentType,
wantBody: "Bad Request: not a supported upstream IDP type for this endpoint: \"oidc\"\n",
},
{
name: "POST request when upstream IDP type in state param is not supported by this endpoint",
method: http.MethodPost,
path: newRequestPath().WithState(
happyUpstreamStateParam().WithUpstreamIDPType("oidc").Build(t, happyStateCodec),
).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusBadRequest,
wantContentType: htmlContentType,
wantBody: "Bad Request: not a supported upstream IDP type for this endpoint: \"oidc\"\n",
},
{
name: "valid GET request when GET endpoint handler returns an error",
method: http.MethodGet,
path: happyPathWithState,
csrfCookie: happyCSRFCookie,
getHandlerErr: httperr.Newf(http.StatusInternalServerError, "some get error"),
wantStatus: http.StatusInternalServerError,
wantContentType: htmlContentType,
wantBody: "Internal Server Error: some get error\n",
wantEncodedState: happyState,
wantDecodedState: expectedHappyDecodedUpstreamStateParam(),
},
{
name: "valid POST request when POST endpoint handler returns an error",
method: http.MethodPost,
path: happyPathWithState,
csrfCookie: happyCSRFCookie,
postHandlerErr: httperr.Newf(http.StatusInternalServerError, "some post error"),
wantStatus: http.StatusInternalServerError,
wantContentType: htmlContentType,
wantBody: "Internal Server Error: some post error\n",
wantEncodedState: happyState,
wantDecodedState: expectedHappyDecodedUpstreamStateParam(),
},
{
name: "happy GET request for LDAP upstream",
method: http.MethodGet,
path: happyPathWithState,
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusOK,
wantContentType: htmlContentType,
wantBody: happyGetResult,
wantEncodedState: happyState,
wantDecodedState: expectedHappyDecodedUpstreamStateParam(),
},
{
name: "happy POST request for LDAP upstream",
method: http.MethodPost,
path: happyPathWithState,
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusOK,
wantContentType: htmlContentType,
wantBody: happyPostResult,
wantEncodedState: happyState,
wantDecodedState: expectedHappyDecodedUpstreamStateParam(),
},
{
name: "happy GET request for ActiveDirectory upstream",
method: http.MethodGet,
path: newRequestPath().WithState(happyActiveDirectoryState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusOK,
wantContentType: htmlContentType,
wantBody: happyGetResult,
wantEncodedState: happyActiveDirectoryState,
wantDecodedState: expectedHappyDecodedUpstreamStateParamForActiveDirectory(),
},
{
name: "happy POST request for ActiveDirectory upstream",
method: http.MethodPost,
path: newRequestPath().WithState(happyActiveDirectoryState).String(),
csrfCookie: happyCSRFCookie,
wantStatus: http.StatusOK,
wantContentType: htmlContentType,
wantBody: happyPostResult,
wantEncodedState: happyActiveDirectoryState,
wantDecodedState: expectedHappyDecodedUpstreamStateParamForActiveDirectory(),
},
}
for _, test := range tests {
tt := test
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(tt.method, tt.path, nil)
if test.csrfCookie != "" {
req.Header.Set("Cookie", test.csrfCookie)
}
rsp := httptest.NewRecorder()
testGetHandler := func(
w http.ResponseWriter,
r *http.Request,
encodedState string,
decodedState *oidc.UpstreamStateParamData,
) error {
require.Equal(t, req, r)
require.Equal(t, rsp, w)
require.Equal(t, tt.wantEncodedState, encodedState)
require.Equal(t, tt.wantDecodedState, decodedState)
if tt.getHandlerErr == nil {
_, err := w.Write([]byte(happyGetResult))
require.NoError(t, err)
}
return tt.getHandlerErr
}
testPostHandler := func(
w http.ResponseWriter,
r *http.Request,
encodedState string,
decodedState *oidc.UpstreamStateParamData,
) error {
require.Equal(t, req, r)
require.Equal(t, rsp, w)
require.Equal(t, tt.wantEncodedState, encodedState)
require.Equal(t, tt.wantDecodedState, decodedState)
if tt.postHandlerErr == nil {
_, err := w.Write([]byte(happyPostResult))
require.NoError(t, err)
}
return tt.postHandlerErr
}
subject := NewHandler(happyStateCodec, happyCookieCodec, testGetHandler, testPostHandler)
subject.ServeHTTP(rsp, req)
testutil.RequireSecurityHeaders(t, rsp)
require.Equal(t, tt.wantStatus, rsp.Code)
testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), tt.wantContentType)
require.Equal(t, tt.wantBody, rsp.Body.String())
})
}
}
type requestPath struct {
state *string
}
func newRequestPath() *requestPath {
return &requestPath{}
}
func (r *requestPath) WithState(state string) *requestPath {
r.state = &state
return r
}
func (r *requestPath) WithoutState() *requestPath {
r.state = nil
return r
}
func (r *requestPath) String() string {
path := "/login?"
params := url.Values{}
if r.state != nil {
params.Add("state", *r.state)
}
return path + params.Encode()
}

View File

@ -0,0 +1,19 @@
// Copyright 2022 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package login
import (
"net/http"
"github.com/ory/fosite"
"go.pinniped.dev/internal/oidc"
)
func NewPostHandler(upstreamIDPs oidc.UpstreamIdentityProvidersLister, oauthHelper fosite.OAuth2Provider) HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, encodedState string, decodedState *oidc.UpstreamStateParamData) error {
// TODO
return nil
}
}

View File

@ -5,12 +5,15 @@
package oidc package oidc
import ( import (
"crypto/subtle"
"net/http"
"time" "time"
coreosoidc "github.com/coreos/go-oidc/v3/oidc" coreosoidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/fosite/compose" "github.com/ory/fosite/compose"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
@ -297,3 +300,68 @@ func ScopeWasRequested(authorizeRequester fosite.AuthorizeRequester, scopeName s
} }
return false return false
} }
func ReadStateParamAndValidateCSRFCookie(r *http.Request, cookieDecoder Decoder, stateDecoder Decoder) (string, *UpstreamStateParamData, error) {
csrfValue, err := readCSRFCookie(r, cookieDecoder)
if err != nil {
return "", nil, err
}
encodedState, decodedState, err := readStateParam(r, stateDecoder)
if err != nil {
return "", nil, err
}
err = validateCSRFValue(decodedState, csrfValue)
if err != nil {
return "", nil, err
}
return encodedState, decodedState, nil
}
func readCSRFCookie(r *http.Request, cookieDecoder Decoder) (csrftoken.CSRFToken, error) {
receivedCSRFCookie, err := r.Cookie(CSRFCookieName)
if err != nil {
// Error means that the cookie was not found
return "", httperr.Wrap(http.StatusForbidden, "CSRF cookie is missing", err)
}
var csrfFromCookie csrftoken.CSRFToken
err = cookieDecoder.Decode(CSRFCookieEncodingName, receivedCSRFCookie.Value, &csrfFromCookie)
if err != nil {
return "", httperr.Wrap(http.StatusForbidden, "error reading CSRF cookie", err)
}
return csrfFromCookie, nil
}
func readStateParam(r *http.Request, stateDecoder Decoder) (string, *UpstreamStateParamData, error) {
encodedState := r.FormValue("state")
if encodedState == "" {
return "", nil, httperr.New(http.StatusBadRequest, "state param not found")
}
var state UpstreamStateParamData
if err := stateDecoder.Decode(
UpstreamStateParamEncodingName,
r.FormValue("state"),
&state,
); err != nil {
return "", nil, httperr.New(http.StatusBadRequest, "error reading state")
}
if state.FormatVersion != UpstreamStateParamFormatVersion {
return "", nil, httperr.New(http.StatusUnprocessableEntity, "state format version is invalid")
}
return encodedState, &state, nil
}
func validateCSRFValue(state *UpstreamStateParamData, csrfCookieValue csrftoken.CSRFToken) error {
if subtle.ConstantTimeCompare([]byte(state.CSRFToken), []byte(csrfCookieValue)) != 1 {
return httperr.New(http.StatusForbidden, "CSRF value does not match")
}
return nil
}

View File

@ -136,7 +136,12 @@ func (m *Manager) SetProviders(federationDomains ...*provider.FederationDomainIs
oauthHelperWithKubeStorage, oauthHelperWithKubeStorage,
) )
m.providerHandlers[(issuerHostWithPath + oidc.PinnipedLoginPath)] = login.NewHandler() m.providerHandlers[(issuerHostWithPath + oidc.PinnipedLoginPath)] = login.NewHandler(
upstreamStateEncoder,
csrfCookieEncoder,
login.NewGetHandler(m.upstreamIDPs),
login.NewPostHandler(m.upstreamIDPs, oauthHelperWithKubeStorage),
)
plog.Debug("oidc provider manager added or updated issuer", "issuer", issuer) plog.Debug("oidc provider manager added or updated issuer", "issuer", issuer)
} }

View File

@ -15,6 +15,7 @@ import (
"time" "time"
coreosoidc "github.com/coreos/go-oidc/v3/oidc" coreosoidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/securecookie"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -837,6 +838,44 @@ type ExpectedUpstreamStateParamFormat struct {
V string `json:"v"` V string `json:"v"`
} }
type UpstreamStateParamBuilder ExpectedUpstreamStateParamFormat
func (b UpstreamStateParamBuilder) Build(t *testing.T, stateEncoder *securecookie.SecureCookie) string {
state, err := stateEncoder.Encode("s", b)
require.NoError(t, err)
return state
}
func (b *UpstreamStateParamBuilder) WithAuthorizeRequestParams(params string) *UpstreamStateParamBuilder {
b.P = params
return b
}
func (b *UpstreamStateParamBuilder) WithNonce(nonce string) *UpstreamStateParamBuilder {
b.N = nonce
return b
}
func (b *UpstreamStateParamBuilder) WithCSRF(csrf string) *UpstreamStateParamBuilder {
b.C = csrf
return b
}
func (b *UpstreamStateParamBuilder) WithPKCE(pkce string) *UpstreamStateParamBuilder {
b.K = pkce
return b
}
func (b *UpstreamStateParamBuilder) WithUpstreamIDPType(upstreamIDPType string) *UpstreamStateParamBuilder {
b.T = upstreamIDPType
return b
}
func (b *UpstreamStateParamBuilder) WithStateVersion(version string) *UpstreamStateParamBuilder {
b.V = version
return b
}
type staticKeySet struct { type staticKeySet struct {
publicKey crypto.PublicKey publicKey crypto.PublicKey
} }