Cleanup code via TODOs accumulated during token endpoint work

We opened https://github.com/vmware-tanzu/pinniped/issues/254 for the TODO in
dynamicOpenIDConnectECDSAStrategy.GenerateToken().

This commit also ensures that linting and unit tests are passing again.

Signed-off-by: Andrew Keesler <akeesler@vmware.com>
This commit is contained in:
Andrew Keesler 2020-12-04 10:06:55 -05:00
parent 83e0934864
commit 03806629b8
No known key found for this signature in database
GPG Key ID: 27CE0444346F9413
12 changed files with 168 additions and 200 deletions

View File

@ -45,7 +45,7 @@ func NewHandler(
authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), r) authorizeRequester, err := oauthHelper.NewAuthorizeRequest(r.Context(), r)
if err != nil { if err != nil {
plog.Info("authorize request error", fositeErrorForLog(err)...) plog.Info("authorize request error", oidc.FositeErrorForLog(err)...)
oauthHelper.WriteAuthorizeError(w, authorizeRequester, err) oauthHelper.WriteAuthorizeError(w, authorizeRequester, err)
return nil return nil
} }
@ -69,7 +69,7 @@ func NewHandler(
}, },
}) })
if err != nil { if err != nil {
plog.Info("authorize response error", fositeErrorForLog(err)...) plog.Info("authorize response error", oidc.FositeErrorForLog(err)...)
oauthHelper.WriteAuthorizeError(w, authorizeRequester, err) oauthHelper.WriteAuthorizeError(w, authorizeRequester, err)
return nil return nil
} }
@ -232,15 +232,3 @@ func addCSRFSetCookieHeader(w http.ResponseWriter, csrfValue csrftoken.CSRFToken
return nil return nil
} }
func fositeErrorForLog(err error) []interface{} {
rfc6749Error := fosite.ErrorToRFC6749Error(err)
keysAndValues := make([]interface{}, 0)
keysAndValues = append(keysAndValues, "name")
keysAndValues = append(keysAndValues, rfc6749Error.Name)
keysAndValues = append(keysAndValues, "status")
keysAndValues = append(keysAndValues, rfc6749Error.Status())
keysAndValues = append(keysAndValues, "description")
keysAndValues = append(keysAndValues, rfc6749Error.Description)
return keysAndValues
}

View File

@ -4,10 +4,8 @@
package auth package auth
import ( import (
"crypto/ecdsa"
"fmt" "fmt"
"html" "html"
"mime"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -21,8 +19,10 @@ import (
"go.pinniped.dev/internal/here" "go.pinniped.dev/internal/here"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/csrftoken" "go.pinniped.dev/internal/oidc/csrftoken"
"go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/oidctestutil" "go.pinniped.dev/internal/oidc/oidctestutil"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/testutil"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
) )
@ -125,9 +125,9 @@ func TestAuthorizationEndpoint(t *testing.T) {
// Configure fosite the same way that the production code would, using NullStorage to turn off storage. // Configure fosite the same way that the production code would, using NullStorage to turn off storage.
oauthStore := oidc.NullStorage{} oauthStore := oidc.NullStorage{}
hmacSecret := []byte("some secret - must have at least 32 bytes") hmacSecret := []byte("some secret - must have at least 32 bytes")
var signingKeyIsUnused *ecdsa.PrivateKey
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret, signingKeyIsUnused) jwksProviderIsUnused := jwks.NewDynamicJWKSProvider()
oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret, jwksProviderIsUnused)
happyCSRF := "test-csrf" happyCSRF := "test-csrf"
happyPKCE := "test-pkce" happyPKCE := "test-pkce"
@ -725,7 +725,7 @@ func TestAuthorizationEndpoint(t *testing.T) {
t.Logf("response body: %q", rsp.Body.String()) t.Logf("response body: %q", rsp.Body.String())
require.Equal(t, test.wantStatus, rsp.Code) require.Equal(t, test.wantStatus, rsp.Code)
requireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType) testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), test.wantContentType)
actualLocation := rsp.Header().Get("Location") actualLocation := rsp.Header().Get("Location")
if test.wantLocationHeader != "" { if test.wantLocationHeader != "" {
@ -826,22 +826,6 @@ func (*errorReturningEncoder) Encode(_ string, _ interface{}) (string, error) {
return "", fmt.Errorf("some encoding error") return "", fmt.Errorf("some encoding error")
} }
func requireEqualContentType(t *testing.T, actual string, expected string) {
t.Helper()
if expected == "" {
require.Empty(t, actual)
return
}
actualContentType, actualContentTypeParams, err := mime.ParseMediaType(expected)
require.NoError(t, err)
expectedContentType, expectedContentTypeParams, err := mime.ParseMediaType(expected)
require.NoError(t, err)
require.Equal(t, actualContentType, expectedContentType)
require.Equal(t, actualContentTypeParams, expectedContentTypeParams)
}
func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder oidc.Codec) { func requireEqualDecodedStateParams(t *testing.T, actualURL string, expectedURL string, stateParamDecoder oidc.Codec) {
t.Helper() t.Helper()
actualLocationURL, err := url.Parse(actualURL) actualLocationURL, err := url.Parse(actualURL)

View File

@ -24,6 +24,7 @@ import (
kubetesting "k8s.io/client-go/testing" kubetesting "k8s.io/client-go/testing"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/oidctestutil" "go.pinniped.dev/internal/oidc/oidctestutil"
"go.pinniped.dev/internal/testutil" "go.pinniped.dev/internal/testutil"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
@ -433,7 +434,8 @@ func TestCallbackEndpoint(t *testing.T) {
oauthStore := oidc.NewKubeStorage(secrets) oauthStore := oidc.NewKubeStorage(secrets)
hmacSecret := []byte("some secret - must have at least 32 bytes") hmacSecret := []byte("some secret - must have at least 32 bytes")
require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes") require.GreaterOrEqual(t, len(hmacSecret), 32, "fosite requires that hmac secrets have at least 32 bytes")
oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret) jwksProviderIsUnused := jwks.NewDynamicJWKSProvider()
oauthHelper := oidc.FositeOauth2Helper(oauthStore, downstreamIssuer, hmacSecret, jwksProviderIsUnused)
idpListGetter := oidctestutil.NewIDPListGetter(&test.idp) idpListGetter := oidctestutil.NewIDPListGetter(&test.idp)
subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec, happyUpstreamRedirectURI) subject := NewHandler(idpListGetter, oauthHelper, happyStateCodec, happyCookieCodec, happyUpstreamRedirectURI)

View File

@ -18,7 +18,14 @@ import (
"go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/jwks"
) )
// TODO: doc me. // dynamicOpenIDConnectECDSAStrategy is an openid.OpenIDConnectTokenStrategy that can dynamically
// load a signing key to issue ID tokens. We want this dynamic capability since our controllers for
// loading OIDCProvider's and signing keys run in parallel, and thus the signing key might not be
// ready when an OIDCProvider is otherwise ready.
//
// If we ever update OIDCProvider's to hold their signing key, we might not need this type, since we
// could have an invariant that routes to an OIDCProvider's endpoints are only wired up if an
// OIDCProvider has a valid signing key.
type dynamicOpenIDConnectECDSAStrategy struct { type dynamicOpenIDConnectECDSAStrategy struct {
fositeConfig *compose.Config fositeConfig *compose.Config
jwksProvider jwks.DynamicJWKSProvider jwksProvider jwks.DynamicJWKSProvider
@ -61,6 +68,5 @@ func (s *dynamicOpenIDConnectECDSAStrategy) GenerateIDToken(
return "", constable.Error("JWK must be of type ecdsa") return "", constable.Error("JWK must be of type ecdsa")
} }
// todo write story/issue about caching this strategy
return compose.NewOpenIDConnectECDSAStrategy(s.fositeConfig, key).GenerateIDToken(ctx, requester) return compose.NewOpenIDConnectECDSAStrategy(s.fositeConfig, key).GenerateIDToken(ctx, requester)
} }

View File

@ -5,15 +5,13 @@ package oidc
import ( import (
"context" "context"
"crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"fmt" "net/url"
"testing" "testing"
coreosoidc "github.com/coreos/go-oidc"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/fosite/compose" "github.com/ory/fosite/compose"
"github.com/ory/fosite/handler/openid" "github.com/ory/fosite/handler/openid"
@ -22,6 +20,7 @@ import (
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
"go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/oidctestutil"
) )
func TestDynamicOpenIDConnectECDSAStrategy(t *testing.T) { func TestDynamicOpenIDConnectECDSAStrategy(t *testing.T) {
@ -30,6 +29,7 @@ func TestDynamicOpenIDConnectECDSAStrategy(t *testing.T) {
clientID = "some-client-id" clientID = "some-client-id"
goodSubject = "some-subject" goodSubject = "some-subject"
goodUsername = "some-username" goodUsername = "some-username"
goodNonce = "some-nonce-that-is-at-least-32-characters-to-meet-entropy-requirements"
) )
ecPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) ecPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
@ -106,6 +106,9 @@ func TestDynamicOpenIDConnectECDSAStrategy(t *testing.T) {
Subject: goodSubject, Subject: goodSubject,
Username: goodUsername, Username: goodUsername,
}, },
Form: url.Values{
"nonce": {goodNonce},
},
} }
idToken, err := s.GenerateIDToken(context.Background(), requester) idToken, err := s.GenerateIDToken(context.Background(), requester)
if test.wantError != "" { if test.wantError != "" {
@ -113,38 +116,16 @@ func TestDynamicOpenIDConnectECDSAStrategy(t *testing.T) {
} else { } else {
require.NoError(t, err) require.NoError(t, err)
// TODO: common-ize this code with token endpoint test.
// TODO: make more assertions about ID token
privateKey, ok := test.wantSigningJWK.Key.(*ecdsa.PrivateKey) privateKey, ok := test.wantSigningJWK.Key.(*ecdsa.PrivateKey)
require.True(t, ok, "wanted private key to be *ecdsa.PrivateKey, but was %T", test.wantSigningJWK) require.True(t, ok, "wanted private key to be *ecdsa.PrivateKey, but was %T", test.wantSigningJWK)
keySet := newStaticKeySet(privateKey.Public()) // Perform a light validation on the token to make sure 1) we passed through the correct
verifyConfig := coreosoidc.Config{ // signing key and 2) we forwarded the fosite.Requester correctly. Token generation is
ClientID: clientID, // tested more expansively in the token endpoint.
SupportedSigningAlgs: []string{coreosoidc.ES256}, token := oidctestutil.VerifyECDSAIDToken(t, goodIssuer, clientID, privateKey, idToken)
} require.Equal(t, goodSubject, token.Subject)
verifier := coreosoidc.NewVerifier(test.issuer, keySet, &verifyConfig) require.Equal(t, goodNonce, token.Nonce)
_, err := verifier.Verify(context.Background(), idToken)
require.NoError(t, err)
} }
}) })
} }
} }
// TODO: de-dep me.
func newStaticKeySet(publicKey crypto.PublicKey) coreosoidc.KeySet {
return &staticKeySet{publicKey}
}
type staticKeySet struct {
publicKey crypto.PublicKey
}
func (s *staticKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) {
jws, err := jose.ParseSigned(jwt)
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
}
return jws.Verify(s.publicKey)
}

View File

@ -134,6 +134,27 @@ func FositeOauth2Helper(
) )
} }
// FositeErrorForLog generates a list of information about the provided Fosite error that can be
// passed to a plog function (e.g., plog.Info()).
//
// Sample usage:
// err := someFositeLibraryFunction()
// if err != nil {
// plog.Info("some error", FositeErrorForLog(err)...)
// ...
// }
func FositeErrorForLog(err error) []interface{} {
rfc6749Error := fosite.ErrorToRFC6749Error(err)
keysAndValues := make([]interface{}, 0)
keysAndValues = append(keysAndValues, "name")
keysAndValues = append(keysAndValues, rfc6749Error.Name)
keysAndValues = append(keysAndValues, "status")
keysAndValues = append(keysAndValues, rfc6749Error.Status())
keysAndValues = append(keysAndValues, "description")
keysAndValues = append(keysAndValues, rfc6749Error.Description)
return keysAndValues
}
type IDPListGetter interface { type IDPListGetter interface {
GetIDPList() []provider.UpstreamOIDCIdentityProviderI GetIDPList() []provider.UpstreamOIDCIdentityProviderI
} }

View File

@ -5,9 +5,16 @@ package oidctestutil
import ( import (
"context" "context"
"crypto"
"crypto/ecdsa"
"fmt"
"net/url" "net/url"
"testing"
coreosoidc "github.com/coreos/go-oidc"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
@ -127,3 +134,41 @@ type ExpectedUpstreamStateParamFormat struct {
K string `json:"k"` K string `json:"k"`
V string `json:"v"` V string `json:"v"`
} }
type staticKeySet struct {
publicKey crypto.PublicKey
}
func newStaticKeySet(publicKey crypto.PublicKey) coreosoidc.KeySet {
return &staticKeySet{publicKey}
}
func (s *staticKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) {
jws, err := jose.ParseSigned(jwt)
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %w", err)
}
return jws.Verify(s.publicKey)
}
// VerifyECDSAIDToken verifies that the provided idToken was issued via the provided jwtSigningKey.
// It also performs some light validation on the claims, i.e., it makes sure the provided idToken
// has the provided issuer and clientID.
//
// Further validation can be done via callers via the returned coreosoidc.IDToken.
func VerifyECDSAIDToken(
t *testing.T,
issuer, clientID string,
jwtSigningKey *ecdsa.PrivateKey,
idToken string,
) *coreosoidc.IDToken {
t.Helper()
keySet := newStaticKeySet(jwtSigningKey.Public())
verifyConfig := coreosoidc.Config{ClientID: clientID, SupportedSigningAlgs: []string{coreosoidc.ES256}}
verifier := coreosoidc.NewVerifier(issuer, keySet, &verifyConfig)
token, err := verifier.Verify(context.Background(), idToken)
require.NoError(t, err)
return token
}

View File

@ -5,12 +5,8 @@ package manager
import ( import (
"context" "context"
"crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/sha256"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -18,7 +14,6 @@ import (
"strings" "strings"
"testing" "testing"
coreosoidc "github.com/coreos/go-oidc"
"github.com/sclevine/spec" "github.com/sclevine/spec"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
@ -30,6 +25,7 @@ import (
"go.pinniped.dev/internal/oidc/jwks" "go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/oidctestutil" "go.pinniped.dev/internal/oidc/oidctestutil"
"go.pinniped.dev/internal/oidc/provider" "go.pinniped.dev/internal/oidc/provider"
"go.pinniped.dev/internal/testutil"
"go.pinniped.dev/pkg/oidcclient/nonce" "go.pinniped.dev/pkg/oidcclient/nonce"
"go.pinniped.dev/pkg/oidcclient/oidctypes" "go.pinniped.dev/pkg/oidcclient/oidctypes"
"go.pinniped.dev/pkg/oidcclient/pkce" "go.pinniped.dev/pkg/oidcclient/pkce"
@ -184,7 +180,6 @@ func TestManager(t *testing.T) {
// Validate ID token is signed by the correct JWK to make sure we wired the token endpoint // Validate ID token is signed by the correct JWK to make sure we wired the token endpoint
// signing key correctly. // signing key correctly.
// TODO: common-ize this code with token endpoint test.
idToken, ok := body["id_token"].(string) idToken, ok := body["id_token"].(string)
r.True(ok, "wanted id_token type to be string, but was %T", body["id_token"]) r.True(ok, "wanted id_token type to be string, but was %T", body["id_token"])
@ -192,11 +187,7 @@ func TestManager(t *testing.T) {
privateKey, ok := jwks.Keys[0].Key.(*ecdsa.PrivateKey) privateKey, ok := jwks.Keys[0].Key.(*ecdsa.PrivateKey)
r.True(ok, "wanted private key to be *ecdsa.PrivateKey, but was %T", jwks.Keys[0].Key) r.True(ok, "wanted private key to be *ecdsa.PrivateKey, but was %T", jwks.Keys[0].Key)
keySet := newStaticKeySet(privateKey.Public()) oidctestutil.VerifyECDSAIDToken(t, jwkIssuer, downstreamClientID, privateKey, idToken)
verifyConfig := coreosoidc.Config{ClientID: downstreamClientID, SupportedSigningAlgs: []string{coreosoidc.ES256}}
verifier := coreosoidc.NewVerifier(jwkIssuer, keySet, &verifyConfig)
_, err := verifier.Verify(context.Background(), idToken)
r.NoError(err)
// Make sure that we wired up the callback endpoint to use kube storage for fosite sessions. // Make sure that we wired up the callback endpoint to use kube storage for fosite sessions.
r.Equal(len(kubeClient.Actions()), numberOfKubeActionsBeforeThisRequest+7, r.Equal(len(kubeClient.Actions()), numberOfKubeActionsBeforeThisRequest+7,
@ -305,7 +296,7 @@ func TestManager(t *testing.T) {
"client_id": []string{downstreamClientID}, "client_id": []string{downstreamClientID},
"state": []string{"some-state-value-that-is-32-byte"}, "state": []string{"some-state-value-that-is-32-byte"},
"nonce": []string{"some-nonce-value-that-is-at-least-32-bytes"}, "nonce": []string{"some-nonce-value-that-is-at-least-32-bytes"},
"code_challenge": []string{doSHA256(downstreamPKCECodeVerifier)}, "code_challenge": []string{testutil.SHA256(downstreamPKCECodeVerifier)},
"code_challenge_method": []string{"S256"}, "code_challenge_method": []string{"S256"},
"redirect_uri": []string{downstreamRedirectURL}, "redirect_uri": []string{downstreamRedirectURL},
}.Encode() }.Encode()
@ -406,24 +397,3 @@ func TestManager(t *testing.T) {
}) })
}) })
} }
func doSHA256(s string) string {
b := sha256.Sum256([]byte(s))
return base64.RawURLEncoding.EncodeToString(b[:])
}
func newStaticKeySet(publicKey crypto.PublicKey) coreosoidc.KeySet {
return &staticKeySet{publicKey}
}
type staticKeySet struct {
publicKey crypto.PublicKey
}
func (s *staticKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) {
jws, err := jose.ParseSigned(jwt)
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
}
return jws.Verify(s.publicKey)
}

View File

@ -11,6 +11,7 @@ import (
"github.com/ory/fosite/handler/openid" "github.com/ory/fosite/handler/openid"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
) )
@ -21,14 +22,14 @@ func NewHandler(
var session openid.DefaultSession var session openid.DefaultSession
accessRequest, err := oauthHelper.NewAccessRequest(r.Context(), r, &session) accessRequest, err := oauthHelper.NewAccessRequest(r.Context(), r, &session)
if err != nil { if err != nil {
plog.Info("token request error", fositeErrorForLog(err)...) plog.Info("token request error", oidc.FositeErrorForLog(err)...)
oauthHelper.WriteAccessError(w, accessRequest, err) oauthHelper.WriteAccessError(w, accessRequest, err)
return nil return nil
} }
accessResponse, err := oauthHelper.NewAccessResponse(r.Context(), accessRequest) accessResponse, err := oauthHelper.NewAccessResponse(r.Context(), accessRequest)
if err != nil { if err != nil {
plog.Info("token response error", fositeErrorForLog(err)...) plog.Info("token response error", oidc.FositeErrorForLog(err)...)
oauthHelper.WriteAccessError(w, accessRequest, err) oauthHelper.WriteAccessError(w, accessRequest, err)
return nil return nil
} }
@ -38,16 +39,3 @@ func NewHandler(
return nil return nil
}) })
} }
// TODO: de-dup me.
func fositeErrorForLog(err error) []interface{} {
rfc6749Error := fosite.ErrorToRFC6749Error(err)
keysAndValues := make([]interface{}, 0)
keysAndValues = append(keysAndValues, "name")
keysAndValues = append(keysAndValues, rfc6749Error.Name)
keysAndValues = append(keysAndValues, "status")
keysAndValues = append(keysAndValues, rfc6749Error.Status())
keysAndValues = append(keysAndValues, "description")
keysAndValues = append(keysAndValues, rfc6749Error.Description)
return keysAndValues
}

View File

@ -5,17 +5,14 @@ package token
import ( import (
"context" "context"
"crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"mime"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
@ -23,7 +20,6 @@ import (
"testing" "testing"
"time" "time"
coreosoidc "github.com/coreos/go-oidc"
"github.com/ory/fosite" "github.com/ory/fosite"
"github.com/ory/fosite/handler/oauth2" "github.com/ory/fosite/handler/oauth2"
"github.com/ory/fosite/handler/openid" "github.com/ory/fosite/handler/openid"
@ -35,6 +31,9 @@ import (
"go.pinniped.dev/internal/here" "go.pinniped.dev/internal/here"
"go.pinniped.dev/internal/oidc" "go.pinniped.dev/internal/oidc"
"go.pinniped.dev/internal/oidc/jwks"
"go.pinniped.dev/internal/oidc/oidctestutil"
"go.pinniped.dev/internal/testutil"
) )
const ( const (
@ -188,7 +187,7 @@ func TestTokenEndpoint(t *testing.T) {
"client_id": {goodClient}, "client_id": {goodClient},
"state": {"some-state-value-that-is-32-byte"}, "state": {"some-state-value-that-is-32-byte"},
"nonce": {goodNonce}, "nonce": {goodNonce},
"code_challenge": {doSHA256(goodPKCECodeVerifier)}, "code_challenge": {testutil.SHA256(goodPKCECodeVerifier)},
"code_challenge_method": {"S256"}, "code_challenge_method": {"S256"},
"redirect_uri": {goodRedirectURI}, "redirect_uri": {goodRedirectURI},
}, },
@ -406,7 +405,7 @@ func TestTokenEndpoint(t *testing.T) {
t.Logf("response body: %q", rsp.Body.String()) t.Logf("response body: %q", rsp.Body.String())
require.Equal(t, test.wantStatus, rsp.Code) require.Equal(t, test.wantStatus, rsp.Code)
requireEqualContentType(t, rsp.Header().Get("Content-Type"), "application/json") testutil.RequireEqualContentType(t, rsp.Header().Get("Content-Type"), "application/json")
if test.wantBodyFields != nil { if test.wantBodyFields != nil {
var m map[string]interface{} var m map[string]interface{}
require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &m)) require.NoError(t, json.Unmarshal(rsp.Body.Bytes(), &m))
@ -444,7 +443,7 @@ func TestTokenEndpoint(t *testing.T) {
subject.ServeHTTP(rsp0, req) subject.ServeHTTP(rsp0, req)
t.Logf("response 0: %#v", rsp0) t.Logf("response 0: %#v", rsp0)
t.Logf("response 0 body: %q", rsp0.Body.String()) t.Logf("response 0 body: %q", rsp0.Body.String())
requireEqualContentType(t, rsp0.Header().Get("Content-Type"), "application/json") testutil.RequireEqualContentType(t, rsp0.Header().Get("Content-Type"), "application/json")
require.Equal(t, http.StatusOK, rsp0.Code) require.Equal(t, http.StatusOK, rsp0.Code)
var m map[string]interface{} var m map[string]interface{}
@ -470,7 +469,7 @@ func TestTokenEndpoint(t *testing.T) {
t.Logf("response 1: %#v", rsp1) t.Logf("response 1: %#v", rsp1)
t.Logf("response 1 body: %q", rsp1.Body.String()) t.Logf("response 1 body: %q", rsp1.Body.String())
require.Equal(t, http.StatusBadRequest, rsp1.Code) require.Equal(t, http.StatusBadRequest, rsp1.Code)
requireEqualContentType(t, rsp1.Header().Get("Content-Type"), "application/json") testutil.RequireEqualContentType(t, rsp1.Header().Get("Content-Type"), "application/json")
require.JSONEq(t, fositeReusedAuthCodeErrorBody, rsp1.Body.String()) require.JSONEq(t, fositeReusedAuthCodeErrorBody, rsp1.Body.String())
requireInvalidAuthCodeStorage(t, code, oauthStore) requireInvalidAuthCodeStorage(t, code, oauthStore)
@ -547,8 +546,8 @@ func makeHappyOauthHelper(
) (fosite.OAuth2Provider, string, *ecdsa.PrivateKey) { ) (fosite.OAuth2Provider, string, *ecdsa.PrivateKey) {
t.Helper() t.Helper()
jwtSigningKey := generateJWTSigningKey(t) jwtSigningKey, jwkProvider := generateJWTSigningKeyAndJWKSProvider(t, goodIssuer)
oauthHelper := oidc.FositeOauth2Helper(store, goodIssuer, []byte(hmacSecret), jwtSigningKey) oauthHelper := oidc.FositeOauth2Helper(store, goodIssuer, []byte(hmacSecret), jwkProvider)
// Simulate the auth endpoint running so Fosite code will fill the store with realistic values. // Simulate the auth endpoint running so Fosite code will fill the store with realistic values.
// //
@ -574,11 +573,21 @@ func makeHappyOauthHelper(
return oauthHelper, authResponder.GetCode(), jwtSigningKey return oauthHelper, authResponder.GetCode(), jwtSigningKey
} }
func generateJWTSigningKey(t *testing.T) *ecdsa.PrivateKey { func generateJWTSigningKeyAndJWKSProvider(t *testing.T, issuer string) (*ecdsa.PrivateKey, jwks.DynamicJWKSProvider) {
t.Helper() t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err) require.NoError(t, err)
return key
jwksProvider := jwks.NewDynamicJWKSProvider()
jwksProvider.SetIssuerToJWKSMap(
nil, // public JWKS unused
map[string]*jose.JSONWebKey{
issuer: {Key: key},
},
)
return key, jwksProvider
} }
func hashAccessToken(accessToken string) string { func hashAccessToken(accessToken string) string {
@ -591,12 +600,6 @@ func hashAccessToken(accessToken string) string {
return base64.RawURLEncoding.EncodeToString(b[:len(b)/2]) return base64.RawURLEncoding.EncodeToString(b[:len(b)/2])
} }
// TODO: de-dup me (manager test).
func doSHA256(s string) string {
b := sha256.Sum256([]byte(s))
return base64.RawURLEncoding.EncodeToString(b[:])
}
func requireInvalidAuthCodeStorage( func requireInvalidAuthCodeStorage(
t *testing.T, t *testing.T,
code string, code string,
@ -736,7 +739,7 @@ func requireValidAuthRequest(
wantGrantedScopes = append([]string{"openid"}, wantGrantedScopes...) wantGrantedScopes = append([]string{"openid"}, wantGrantedScopes...)
} }
require.NotEmpty(t, authRequest.GetID()) require.NotEmpty(t, authRequest.GetID())
requireTimeInDelta(t, authRequest.GetRequestedAt(), time.Now().UTC(), timeComparisonFudgeSeconds*time.Second) testutil.RequireTimeInDelta(t, authRequest.GetRequestedAt(), time.Now().UTC(), timeComparisonFudgeSeconds*time.Second)
require.Equal(t, goodClient, authRequest.GetClient().GetID()) require.Equal(t, goodClient, authRequest.GetClient().GetID())
require.Equal(t, fosite.Arguments(wantRequestedScopes), authRequest.GetRequestedScopes()) require.Equal(t, fosite.Arguments(wantRequestedScopes), authRequest.GetRequestedScopes())
require.Equal(t, fosite.Arguments(wantGrantedScopes), authRequest.GetGrantedScopes()) require.Equal(t, fosite.Arguments(wantGrantedScopes), authRequest.GetGrantedScopes())
@ -756,13 +759,13 @@ func requireValidAuthRequest(
require.Equal(t, goodSubject, claims.Subject) require.Equal(t, goodSubject, claims.Subject)
require.Equal(t, []string{goodClient}, claims.Audience) require.Equal(t, []string{goodClient}, claims.Audience)
require.Equal(t, goodNonce, claims.Nonce) require.Equal(t, goodNonce, claims.Nonce)
requireTimeInDelta( testutil.RequireTimeInDelta(
t, t,
time.Now().UTC().Add(idTokenExpirationSeconds*time.Second), time.Now().UTC().Add(idTokenExpirationSeconds*time.Second),
claims.ExpiresAt, claims.ExpiresAt,
timeComparisonFudgeSeconds*time.Second, timeComparisonFudgeSeconds*time.Second,
) )
requireTimeInDelta(t, time.Now().UTC(), claims.IssuedAt, timeComparisonFudgeSeconds*time.Second) testutil.RequireTimeInDelta(t, time.Now().UTC(), claims.IssuedAt, timeComparisonFudgeSeconds*time.Second)
require.Equal(t, wantAccessTokenHash, claims.AccessTokenHash) require.Equal(t, wantAccessTokenHash, claims.AccessTokenHash)
// We are in charge of setting these fields. For the purpose of testing, we ensure that the // We are in charge of setting these fields. For the purpose of testing, we ensure that the
@ -784,7 +787,7 @@ func requireValidAuthRequest(
// Assert that the token expirations are what we think they should be. // Assert that the token expirations are what we think they should be.
authCodeExpiresAt, ok := session.ExpiresAt[fosite.AuthorizeCode] authCodeExpiresAt, ok := session.ExpiresAt[fosite.AuthorizeCode]
require.True(t, ok, "expected session to hold expiration time for auth code") require.True(t, ok, "expected session to hold expiration time for auth code")
requireTimeInDelta( testutil.RequireTimeInDelta(
t, t,
time.Now().UTC().Add(authCodeExpirationSeconds*time.Second), time.Now().UTC().Add(authCodeExpirationSeconds*time.Second),
authCodeExpiresAt, authCodeExpiresAt,
@ -792,7 +795,7 @@ func requireValidAuthRequest(
) )
accessTokenExpiresAt, ok := session.ExpiresAt[fosite.AccessToken] accessTokenExpiresAt, ok := session.ExpiresAt[fosite.AccessToken]
require.True(t, ok, "expected session to hold expiration time for access token") require.True(t, ok, "expected session to hold expiration time for access token")
requireTimeInDelta( testutil.RequireTimeInDelta(
t, t,
time.Now().UTC().Add(accessTokenExpirationSeconds*time.Second), time.Now().UTC().Add(accessTokenExpirationSeconds*time.Second),
accessTokenExpiresAt, accessTokenExpiresAt,
@ -805,17 +808,15 @@ func requireValidAuthRequest(
} }
func requireValidIDToken(t *testing.T, body map[string]interface{}, jwtSigningKey *ecdsa.PrivateKey) { func requireValidIDToken(t *testing.T, body map[string]interface{}, jwtSigningKey *ecdsa.PrivateKey) {
t.Helper()
idToken, ok := body["id_token"] idToken, ok := body["id_token"]
require.Truef(t, ok, "body did not contain 'id_token': %s", body) require.Truef(t, ok, "body did not contain 'id_token': %s", body)
idTokenString, ok := idToken.(string) idTokenString, ok := idToken.(string)
require.Truef(t, ok, "wanted id_token to be a string, but got %T", idToken) require.Truef(t, ok, "wanted id_token to be a string, but got %T", idToken)
// The go-oidc library will validate the signature and the client claim in the ID token. // The go-oidc library will validate the signature and the client claim in the ID token.
keySet := newStaticKeySet(jwtSigningKey.Public()) token := oidctestutil.VerifyECDSAIDToken(t, goodIssuer, goodClient, jwtSigningKey, idTokenString)
verifyConfig := coreosoidc.Config{ClientID: goodClient, SupportedSigningAlgs: []string{coreosoidc.ES256}}
verifier := coreosoidc.NewVerifier(goodIssuer, keySet, &verifyConfig)
token, err := verifier.Verify(context.Background(), idTokenString)
require.NoError(t, err)
var claims struct { var claims struct {
Subject string `json:"sub"` Subject string `json:"sub"`
@ -837,7 +838,7 @@ func requireValidIDToken(t *testing.T, body map[string]interface{}, jwtSigningKe
require.ElementsMatch(t, idTokenFields, getMapKeys(m)) require.ElementsMatch(t, idTokenFields, getMapKeys(m))
// verify each of the claims // verify each of the claims
err = token.Claims(&claims) err := token.Claims(&claims)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, goodSubject, claims.Subject) require.Equal(t, goodSubject, claims.Subject)
require.Len(t, claims.Audience, 1) require.Len(t, claims.Audience, 1)
@ -851,60 +852,10 @@ func requireValidIDToken(t *testing.T, body map[string]interface{}, jwtSigningKe
issuedAt := time.Unix(claims.IssuedAt, 0) issuedAt := time.Unix(claims.IssuedAt, 0)
requestedAt := time.Unix(claims.RequestedAt, 0) requestedAt := time.Unix(claims.RequestedAt, 0)
authTime := time.Unix(claims.AuthTime, 0) authTime := time.Unix(claims.AuthTime, 0)
requireTimeInDelta(t, time.Now().UTC().Add(idTokenExpirationSeconds*time.Second), expiresAt, timeComparisonFudgeSeconds*time.Second) testutil.RequireTimeInDelta(t, time.Now().UTC().Add(idTokenExpirationSeconds*time.Second), expiresAt, timeComparisonFudgeSeconds*time.Second)
requireTimeInDelta(t, time.Now().UTC(), issuedAt, timeComparisonFudgeSeconds*time.Second) testutil.RequireTimeInDelta(t, time.Now().UTC(), issuedAt, timeComparisonFudgeSeconds*time.Second)
requireTimeInDelta(t, goodRequestedAtTime, requestedAt, timeComparisonFudgeSeconds*time.Second) testutil.RequireTimeInDelta(t, goodRequestedAtTime, requestedAt, timeComparisonFudgeSeconds*time.Second)
requireTimeInDelta(t, goodAuthTime, authTime, timeComparisonFudgeSeconds*time.Second) testutil.RequireTimeInDelta(t, goodAuthTime, authTime, timeComparisonFudgeSeconds*time.Second)
}
// TODO: de-dup me (manager test).
func newStaticKeySet(publicKey crypto.PublicKey) coreosoidc.KeySet {
return &staticKeySet{publicKey}
}
type staticKeySet struct {
publicKey crypto.PublicKey
}
func (s *staticKeySet) VerifySignature(ctx context.Context, jwt string) ([]byte, error) {
jws, err := jose.ParseSigned(jwt)
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt: %v", err)
}
return jws.Verify(s.publicKey)
}
// TODO: de-dup me.
func requireEqualContentType(t *testing.T, actual string, expected string) {
t.Helper()
if expected == "" {
require.Empty(t, actual)
return
}
actualContentType, actualContentTypeParams, err := mime.ParseMediaType(expected)
require.NoError(t, err)
expectedContentType, expectedContentTypeParams, err := mime.ParseMediaType(expected)
require.NoError(t, err)
require.Equal(t, actualContentType, expectedContentType)
require.Equal(t, actualContentTypeParams, expectedContentTypeParams)
}
// TODO: use actual testutil function.
//nolint:unparam
func requireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Duration) {
t.Helper()
require.InDeltaf(t,
float64(t1.UnixNano()),
float64(t2.UnixNano()),
float64(delta.Nanoseconds()),
"expected %s and %s to be < %s apart, but they are %s apart",
t1.Format(time.RFC3339Nano),
t2.Format(time.RFC3339Nano),
delta.String(),
t1.Sub(t2).String(),
)
} }
func deepCopyRequestForm(r *http.Request) *http.Request { func deepCopyRequestForm(r *http.Request) *http.Request {

View File

@ -4,6 +4,7 @@
package testutil package testutil
import ( import (
"mime"
"testing" "testing"
"time" "time"
@ -22,3 +23,19 @@ func RequireTimeInDelta(t *testing.T, t1 time.Time, t2 time.Time, delta time.Dur
t1.Sub(t2).String(), t1.Sub(t2).String(),
) )
} }
func RequireEqualContentType(t *testing.T, actual string, expected string) {
t.Helper()
if expected == "" {
require.Empty(t, actual)
return
}
actualContentType, actualContentTypeParams, err := mime.ParseMediaType(expected)
require.NoError(t, err)
expectedContentType, expectedContentTypeParams, err := mime.ParseMediaType(expected)
require.NoError(t, err)
require.Equal(t, actualContentType, expectedContentType)
require.Equal(t, actualContentTypeParams, expectedContentTypeParams)
}

View File

@ -0,0 +1,15 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package testutil
import (
"crypto/sha256"
"encoding/base64"
)
// SHA256 returns the base64 URL encoding of the SHA256 sum of the provided string.
func SHA256(s string) string {
b := sha256.Sum256([]byte(s))
return base64.RawURLEncoding.EncodeToString(b[:])
}