Implement the rest of an OIDC client CLI library.

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2020-10-06 17:27:36 -05:00
parent ce49d8bd7b
commit 67b692b11f
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
18 changed files with 1199 additions and 286 deletions

View File

@ -4,101 +4,75 @@
package cmd package cmd
import ( import (
"fmt" "encoding/json"
"github.com/coreos/go-oidc"
"github.com/pkg/browser"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/oauth2" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
clientauthenticationv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1"
"go.pinniped.dev/internal/oidc/pkce" "go.pinniped.dev/internal/oidcclient/login"
"go.pinniped.dev/internal/oidc/state"
) )
//nolint: gochecknoinits //nolint: gochecknoinits
func init() { func init() {
loginCmd.AddCommand((&oidcLoginParams{ loginCmd.AddCommand(oidcLoginCommand(login.Run))
generateState: state.Generate,
generatePKCE: pkce.Generate,
openURL: browser.OpenURL,
}).cmd())
} }
type oidcLoginParams struct { func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...login.Option) (*login.Token, error)) *cobra.Command {
// These parameters capture CLI flags. var (
issuer string cmd = cobra.Command{
clientID string Args: cobra.NoArgs,
listenPort uint16 Use: "oidc --issuer ISSUER --client-id CLIENT_ID",
scopes []string Short: "Login using an OpenID Connect provider",
skipBrowser bool SilenceUsage: true,
}
// These parameters capture dependencies that we want to mock during testing. issuer string
generateState func() (state.State, error) clientID string
generatePKCE func() (pkce.Code, error) listenPort uint16
openURL func(string) error scopes []string
} skipBrowser bool
)
func (o *oidcLoginParams) cmd() *cobra.Command { cmd.Flags().StringVar(&issuer, "issuer", "", "OpenID Connect issuer URL.")
cmd := cobra.Command{ cmd.Flags().StringVar(&clientID, "client-id", "", "OpenID Connect client ID.")
Args: cobra.NoArgs, cmd.Flags().Uint16Var(&listenPort, "listen-port", 0, "TCP port for localhost listener (authorization code flow only).")
Use: "oidc --issuer ISSUER --client-id CLIENT_ID", cmd.Flags().StringSliceVar(&scopes, "scopes", []string{"offline_access", "openid", "email", "profile"}, "OIDC scopes to request during login.")
Short: "Login using an OpenID Connect provider", cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL).")
RunE: o.runE,
SilenceUsage: true,
}
cmd.Flags().StringVar(&o.issuer, "issuer", "", "OpenID Connect issuer URL.")
cmd.Flags().StringVar(&o.clientID, "client-id", "", "OpenID Connect client ID.")
cmd.Flags().Uint16Var(&o.listenPort, "listen-port", 48095, "TCP port for localhost listener (authorization code flow only).")
cmd.Flags().StringSliceVar(&o.scopes, "scopes", []string{"offline_access", "openid", "email", "profile"}, "OIDC scopes to request during login.")
cmd.Flags().BoolVar(&o.skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL).")
mustMarkRequired(&cmd, "issuer", "client-id") mustMarkRequired(&cmd, "issuer", "client-id")
cmd.RunE = func(cmd *cobra.Command, args []string) error {
opts := []login.Option{
login.WithContext(cmd.Context()),
login.WithScopes(scopes),
}
if listenPort != 0 {
opts = append(opts, login.WithListenPort(listenPort))
}
// --skip-browser replaces the default "browser open" function with one that prints to stderr.
if skipBrowser {
opts = append(opts, login.WithBrowserOpen(func(url string) error {
cmd.PrintErr("Please log in: ", url, "\n")
return nil
}))
}
tok, err := loginFunc(issuer, clientID, opts...)
if err != nil {
return err
}
// Convert the token out to Kubernetes ExecCredential JSON format for output.
return json.NewEncoder(cmd.OutOrStdout()).Encode(&clientauthenticationv1beta1.ExecCredential{
TypeMeta: metav1.TypeMeta{
Kind: "ExecCredential",
APIVersion: "client.authentication.k8s.io/v1beta1",
},
Status: &clientauthenticationv1beta1.ExecCredentialStatus{
ExpirationTimestamp: &metav1.Time{Time: tok.IDTokenExpiry},
Token: tok.IDToken,
},
})
}
return &cmd return &cmd
} }
func (o *oidcLoginParams) runE(cmd *cobra.Command, args []string) error {
metadata, err := oidc.NewProvider(cmd.Context(), o.issuer)
if err != nil {
return fmt.Errorf("could not perform OIDC discovery for %q: %w", o.issuer, err)
}
cfg := oauth2.Config{
ClientID: o.clientID,
Endpoint: metadata.Endpoint(),
RedirectURL: fmt.Sprintf("http://localhost:%d/callback", o.listenPort),
Scopes: o.scopes,
}
authCodeOptions := []oauth2.AuthCodeOption{oauth2.AccessTypeOffline}
stateParam, err := o.generateState()
if err != nil {
return fmt.Errorf("could not generate OIDC state parameter: %w", err)
}
// We can always use PKCE (RFC 7636) because the server should always ignore the parameters if it doesn't
// understand them. Per https://tools.ietf.org/html/rfc7636#section-5:
// As the OAuth 2.0 [RFC6749] server responses are unchanged by this specification, client implementations of
// this specification do not need to know if the server has implemented this specification or not and SHOULD
// send the additional parameters as defined in Section 4 to all servers.
pkceCode, err := o.generatePKCE()
if err != nil {
return fmt.Errorf("could not generate OIDC PKCE parameter: %w", err)
}
authCodeOptions = append(authCodeOptions, pkceCode.Challenge(), pkceCode.Method())
// If --skip-browser was passed, override the default browser open function with a Printf() call.
openURL := o.openURL
if o.skipBrowser {
openURL = func(s string) error {
cmd.PrintErr("Please log in: ", s, "\n")
return nil
}
}
authorizeURL := cfg.AuthCodeURL(stateParam.String(), authCodeOptions...)
if err := openURL(authorizeURL); err != nil {
return fmt.Errorf("could not open browser (run again with --skip-browser?): %w", err)
}
return nil
}

View File

@ -5,30 +5,29 @@ package cmd
import ( import (
"bytes" "bytes"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"testing" "testing"
"time"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.pinniped.dev/internal/here" "go.pinniped.dev/internal/here"
"go.pinniped.dev/internal/oidc/pkce" "go.pinniped.dev/internal/oidcclient/login"
"go.pinniped.dev/internal/oidc/state"
) )
func TestLoginOIDCCommand(t *testing.T) { func TestLoginOIDCCommand(t *testing.T) {
t.Parallel() t.Parallel()
time1 := time.Date(3020, 10, 12, 13, 14, 15, 16, time.UTC)
tests := []struct { tests := []struct {
name string name string
args []string args []string
wantError bool wantError bool
wantStdout string wantStdout string
wantStderr string wantStderr string
wantIssuer string
wantClientID string
wantOptionsCount int
}{ }{
{ {
name: "help flag passed", name: "help flag passed",
@ -43,7 +42,7 @@ func TestLoginOIDCCommand(t *testing.T) {
--client-id string OpenID Connect client ID. --client-id string OpenID Connect client ID.
-h, --help help for oidc -h, --help help for oidc
--issuer string OpenID Connect issuer URL. --issuer string OpenID Connect issuer URL.
--listen-port uint16 TCP port for localhost listener (authorization code flow only). (default 48095) --listen-port uint16 TCP port for localhost listener (authorization code flow only).
--scopes strings OIDC scopes to request during login. (default [offline_access,openid,email,profile]) --scopes strings OIDC scopes to request during login. (default [offline_access,openid,email,profile])
--skip-browser Skip opening the browser (just print the URL). --skip-browser Skip opening the browser (just print the URL).
`), `),
@ -56,12 +55,46 @@ func TestLoginOIDCCommand(t *testing.T) {
Error: required flag(s) "client-id", "issuer" not set Error: required flag(s) "client-id", "issuer" not set
`), `),
}, },
{
name: "success with minimal options",
args: []string{
"--client-id", "test-client-id",
"--issuer", "test-issuer",
},
wantIssuer: "test-issuer",
wantClientID: "test-client-id",
wantOptionsCount: 2,
wantStdout: `{"kind":"ExecCredential","apiVersion":"client.authentication.k8s.io/v1beta1","spec":{},"status":{"expirationTimestamp":"3020-10-12T13:14:15Z","token":"test-id-token"}}` + "\n",
},
{
name: "success with all options",
args: []string{
"--client-id", "test-client-id",
"--issuer", "test-issuer",
"--skip-browser",
"--listen-port", "1234",
},
wantIssuer: "test-issuer",
wantClientID: "test-client-id",
wantOptionsCount: 4,
wantStdout: `{"kind":"ExecCredential","apiVersion":"client.authentication.k8s.io/v1beta1","spec":{},"status":{"expirationTimestamp":"3020-10-12T13:14:15Z","token":"test-id-token"}}` + "\n",
},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
cmd := (&oidcLoginParams{}).cmd() var (
gotIssuer string
gotClientID string
gotOptions []login.Option
)
cmd := oidcLoginCommand(func(issuer string, clientID string, opts ...login.Option) (*login.Token, error) {
gotIssuer = issuer
gotClientID = clientID
gotOptions = opts
return &login.Token{IDToken: "test-id-token", IDTokenExpiry: time1}, nil
})
require.NotNil(t, cmd) require.NotNil(t, cmd)
var stdout, stderr bytes.Buffer var stdout, stderr bytes.Buffer
@ -76,148 +109,9 @@ func TestLoginOIDCCommand(t *testing.T) {
} }
require.Equal(t, tt.wantStdout, stdout.String(), "unexpected stdout") require.Equal(t, tt.wantStdout, stdout.String(), "unexpected stdout")
require.Equal(t, tt.wantStderr, stderr.String(), "unexpected stderr") require.Equal(t, tt.wantStderr, stderr.String(), "unexpected stderr")
}) require.Equal(t, tt.wantIssuer, gotIssuer, "unexpected issuer")
} require.Equal(t, tt.wantClientID, gotClientID, "unexpected client ID")
} require.Len(t, gotOptions, tt.wantOptionsCount)
func TestOIDCLoginRunE(t *testing.T) {
t.Parallel()
// Start a server that returns 500 errors.
brokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}))
t.Cleanup(brokenServer.Close)
// Start a server that returns successfully.
var validResponse string
validServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(validResponse))
}))
t.Cleanup(validServer.Close)
validResponse = strings.ReplaceAll(here.Docf(`
{
"issuer": "${ISSUER}",
"authorization_endpoint": "${ISSUER}/auth",
"token_endpoint": "${ISSUER}/token",
"jwks_uri": "${ISSUER}/keys",
"userinfo_endpoint": "${ISSUER}/userinfo",
"grant_types_supported": ["authorization_code","refresh_token"],
"response_types_supported": ["code"],
"id_token_signing_alg_values_supported": ["RS256"],
"scopes_supported": ["openid","email","groups","profile","offline_access"],
"token_endpoint_auth_methods_supported": ["client_secret_basic"],
"claims_supported": ["aud","email","email_verified","exp","iat","iss","locale","name","sub"]
}
`), "${ISSUER}", validServer.URL)
validServerURL, err := url.Parse(validServer.URL)
require.NoError(t, err)
tests := []struct {
name string
params oidcLoginParams
wantError string
wantStdout string
wantStderr string
wantStderrAuthURL func(*testing.T, *url.URL)
}{
{
name: "broken discovery",
params: oidcLoginParams{
issuer: brokenServer.URL,
},
wantError: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: Internal Server Error\n", brokenServer.URL),
},
{
name: "broken state generation",
params: oidcLoginParams{
issuer: validServer.URL,
generateState: func() (state.State, error) { return "", fmt.Errorf("some error generating a state value") },
},
wantError: "could not generate OIDC state parameter: some error generating a state value",
},
{
name: "broken PKCE generation",
params: oidcLoginParams{
issuer: validServer.URL,
generateState: func() (state.State, error) { return "test-state", nil },
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some error generating a PKCE code") },
},
wantError: "could not generate OIDC PKCE parameter: some error generating a PKCE code",
},
{
name: "broken browser open",
params: oidcLoginParams{
issuer: validServer.URL,
generateState: func() (state.State, error) { return "test-state", nil },
generatePKCE: func() (pkce.Code, error) { return "test-pkce", nil },
openURL: func(_ string) error { return fmt.Errorf("some browser open error") },
},
wantError: "could not open browser (run again with --skip-browser?): some browser open error",
},
{
name: "success",
params: oidcLoginParams{
issuer: validServer.URL,
clientID: "test-client-id",
generateState: func() (state.State, error) { return "test-state", nil },
generatePKCE: func() (pkce.Code, error) { return "test-pkce", nil },
listenPort: 12345,
skipBrowser: true,
},
wantStderrAuthURL: func(t *testing.T, actual *url.URL) {
require.Equal(t, validServerURL.Host, actual.Host)
require.Equal(t, "/auth", actual.Path)
require.Equal(t, "", actual.Fragment)
require.Equal(t, url.Values{
"access_type": []string{"offline"},
"client_id": []string{"test-client-id"},
"redirect_uri": []string{"http://localhost:12345/callback"},
"response_type": []string{"code"},
"state": []string{"test-state"},
"code_challenge_method": []string{"S256"},
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
}, actual.Query())
},
wantStderr: "Please log in: <URL>\n",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
var stdout, stderr bytes.Buffer
cmd := cobra.Command{RunE: tt.params.runE, SilenceUsage: true, SilenceErrors: true}
cmd.SetOut(&stdout)
cmd.SetErr(&stderr)
err := cmd.Execute()
if tt.wantError != "" {
require.EqualError(t, err, tt.wantError)
} else {
require.NoError(t, err)
}
if tt.wantStderrAuthURL != nil {
var urls []string
redacted := regexp.MustCompile(`http://\S+`).ReplaceAllStringFunc(stderr.String(), func(url string) string {
urls = append(urls, url)
return "<URL>"
})
require.Lenf(t, urls, 1, "expected to find authorization URL in stderr:\n%s", stderr.String())
authURL, err := url.Parse(urls[0])
require.NoError(t, err, "invalid authorization URL")
tt.wantStderrAuthURL(t, authURL)
// Replace the stderr buffer with the redacted version.
stderr.Reset()
stderr.WriteString(redacted)
}
require.Equal(t, tt.wantStdout, stdout.String(), "unexpected stdout")
require.Equal(t, tt.wantStderr, stderr.String(), "unexpected stderr")
}) })
} }
} }

2
go.mod
View File

@ -13,7 +13,6 @@ require (
github.com/golangci/golangci-lint v1.31.0 github.com/golangci/golangci-lint v1.31.0
github.com/google/go-cmp v0.5.2 github.com/google/go-cmp v0.5.2
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4
github.com/pkg/errors v0.9.1
github.com/sclevine/spec v1.4.0 github.com/sclevine/spec v1.4.0
github.com/spf13/cobra v1.0.0 github.com/spf13/cobra v1.0.0
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
@ -22,6 +21,7 @@ require (
go.pinniped.dev/generated/1.19/client v0.0.0-00010101000000-000000000000 go.pinniped.dev/generated/1.19/client v0.0.0-00010101000000-000000000000
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6 golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6
gopkg.in/square/go-jose.v2 v2.2.2
k8s.io/api v0.19.2 k8s.io/api v0.19.2
k8s.io/apimachinery v0.19.2 k8s.io/apimachinery v0.19.2
k8s.io/apiserver v0.19.2 k8s.io/apiserver v0.19.2

View File

@ -0,0 +1,67 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package httperr contains some helpers for nicer error handling in http.Handler implementations.
package httperr
import (
"fmt"
"net/http"
)
// Responder represents an error that can emit a useful HTTP error response to an http.ResponseWriter.
type Responder interface {
error
Respond(http.ResponseWriter)
}
// New returns a Responder that emits the given HTTP status code and message.
func New(code int, msg string) error {
return httpErr{code: code, msg: msg}
}
// Newf returns a Responder that emits the given HTTP status code and fmt.Sprintf formatted message.
func Newf(code int, format string, args ...interface{}) error {
return httpErr{code: code, msg: fmt.Sprintf(format, args...)}
}
// Wrap returns a Responder that emits the given HTTP status code and message, and also wraps an internal error.
func Wrap(code int, msg string, cause error) error {
return httpErr{code: code, msg: msg, cause: cause}
}
type httpErr struct {
code int
msg string
cause error
}
func (e httpErr) Error() string {
if e.cause != nil {
return fmt.Sprintf("%s: %v", e.msg, e.cause)
}
return e.msg
}
func (e httpErr) Respond(w http.ResponseWriter) {
// http.Error is important here because it prevents content sniffing by forcing text/plain.
http.Error(w, http.StatusText(e.code)+": "+e.msg, e.code)
}
func (e httpErr) Unwrap() error {
return e.cause
}
// HandlerFunc is like http.HandlerFunc, but with a function signature that allows easier error handling.
type HandlerFunc func(http.ResponseWriter, *http.Request) error
func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch err := f(w, r).(type) {
case nil:
return
case Responder:
err.Respond(w)
default:
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}
}

View File

@ -0,0 +1,52 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package httperr
import (
"errors"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestHTTPErrs(t *testing.T) {
t.Run("new", func(t *testing.T) {
err := New(http.StatusBadRequest, "bad request error")
require.EqualError(t, err, "bad request error")
})
t.Run("newf", func(t *testing.T) {
err := Newf(http.StatusMethodNotAllowed, "expected method %s", "POST")
require.EqualError(t, err, "expected method POST")
})
t.Run("wrap", func(t *testing.T) {
wrappedErr := fmt.Errorf("some internal error")
err := Wrap(http.StatusInternalServerError, "unexpected error", wrappedErr)
require.EqualError(t, err, "unexpected error: some internal error")
require.True(t, errors.Is(err, wrappedErr), "expected error to be wrapped")
})
t.Run("respond", func(t *testing.T) {
err := Wrap(http.StatusForbidden, "boring public bits", fmt.Errorf("some secret internal bits"))
require.Implements(t, (*Responder)(nil), err)
rec := httptest.NewRecorder()
err.(Responder).Respond(rec)
require.Equal(t, http.StatusForbidden, rec.Code)
require.Equal(t, "Forbidden: boring public bits\n", rec.Body.String())
require.Equal(t, http.Header{
"Content-Type": []string{"text/plain; charset=utf-8"},
"X-Content-Type-Options": []string{"nosniff"},
}, rec.Header())
})
}
func TestHandlerFunc(t *testing.T) {
t.Run("success", func(t *testing.T) {
})
}

View File

@ -0,0 +1,20 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package securityheader implements an HTTP middleware for setting security-related response headers.
package securityheader
import "net/http"
// Wrap the provided http.Handler so it sets appropriate security-related response headers.
func Wrap(wrapped http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
h.Set("Content-Security-Policy", "default-src 'none'; frame-ancestors 'none'")
h.Set("X-Frame-Options", "DENY")
h.Set("X-XSS-Protection", "1; mode=block")
h.Set("X-Content-Type-Options", "nosniff")
h.Set("Referrer-Policy", "no-referrer")
wrapped.ServeHTTP(w, r)
})
}

View File

@ -0,0 +1,30 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package securityheader
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestWrap(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("hello world"))
})
rec := httptest.NewRecorder()
Wrap(handler).ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))
require.Equal(t, http.StatusOK, rec.Code)
require.Equal(t, "hello world", rec.Body.String())
require.EqualValues(t, http.Header{
"Content-Security-Policy": []string{"default-src 'none'; frame-ancestors 'none'"},
"Content-Type": []string{"text/plain; charset=utf-8"},
"Referrer-Policy": []string{"no-referrer"},
"X-Content-Type-Options": []string{"nosniff"},
"X-Frame-Options": []string{"DENY"},
"X-Xss-Protection": []string{"1; mode=block"},
}, rec.Header())
}

View File

@ -0,0 +1,6 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package mockkeyset
//go:generate go run -v github.com/golang/mock/mockgen -destination=mockkeyset.go -package=mockkeyset -copyright_file=../../../hack/header.txt github.com/coreos/go-oidc KeySet

View File

@ -0,0 +1,53 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/coreos/go-oidc (interfaces: KeySet)
// Package mockkeyset is a generated GoMock package.
package mockkeyset
import (
context "context"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockKeySet is a mock of KeySet interface
type MockKeySet struct {
ctrl *gomock.Controller
recorder *MockKeySetMockRecorder
}
// MockKeySetMockRecorder is the mock recorder for MockKeySet
type MockKeySetMockRecorder struct {
mock *MockKeySet
}
// NewMockKeySet creates a new mock instance
func NewMockKeySet(ctrl *gomock.Controller) *MockKeySet {
mock := &MockKeySet{ctrl: ctrl}
mock.recorder = &MockKeySetMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockKeySet) EXPECT() *MockKeySetMockRecorder {
return m.recorder
}
// VerifySignature mocks base method
func (m *MockKeySet) VerifySignature(arg0 context.Context, arg1 string) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "VerifySignature", arg0, arg1)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// VerifySignature indicates an expected call of VerifySignature
func (mr *MockKeySetMockRecorder) VerifySignature(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifySignature", reflect.TypeOf((*MockKeySet)(nil).VerifySignature), arg0, arg1)
}

View File

@ -1,37 +0,0 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package state
import (
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"io"
"github.com/pkg/errors"
)
// Generate generates a new random state parameter of an appropriate size.
func Generate() (State, error) { return generate(rand.Reader) }
func generate(rand io.Reader) (State, error) {
var buf [16]byte
if _, err := io.ReadFull(rand, buf[:]); err != nil {
return "", errors.WithMessage(err, "could not generate random state")
}
return State(hex.EncodeToString(buf[:])), nil
}
// State implements some utilities for working with OAuth2 state parameters.
type State string
// String returns the string encoding of this state value.
func (s *State) String() string {
return string(*s)
}
// Validate the returned state (from a callback parameter). Returns true iff the state is valid.
func (s *State) Valid(returnedState string) bool {
return subtle.ConstantTimeCompare([]byte(returnedState), []byte(*s)) == 1
}

View File

@ -0,0 +1,271 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package login implements a CLI OIDC login flow.
package login
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"time"
"github.com/coreos/go-oidc"
"github.com/pkg/browser"
"golang.org/x/oauth2"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader"
"go.pinniped.dev/internal/oidcclient/nonce"
"go.pinniped.dev/internal/oidcclient/pkce"
"go.pinniped.dev/internal/oidcclient/state"
)
type handlerState struct {
// Basic parameters.
ctx context.Context
issuer string
clientID string
scopes []string
// Parameters of the localhost listener.
listenAddr string
callbackPath string
// Generated parameters of a login flow.
idTokenVerifier *oidc.IDTokenVerifier
oauth2Config *oauth2.Config
state state.State
nonce nonce.Nonce
pkce pkce.Code
// External calls for things.
generateState func() (state.State, error)
generatePKCE func() (pkce.Code, error)
generateNonce func() (nonce.Nonce, error)
openURL func(string) error
callbacks chan callbackResult
}
type callbackResult struct {
token *Token
err error
}
type Token struct {
*oauth2.Token
IDToken string `json:"id_token"`
IDTokenExpiry time.Time `json:"id_token_expiry"`
}
// Option is an optional configuration for Run().
type Option func(*handlerState) error
// WithContext specifies a specific context.Context under which to perform the login. If this option is not specified,
// login happens under context.Background().
func WithContext(ctx context.Context) Option {
return func(h *handlerState) error {
h.ctx = ctx
return nil
}
}
// WithListenPort specifies a TCP listen port on localhost, which will be used for the redirect_uri and to handle the
// authorization code callback. By default, a random high port will be chosen which requires the authorization server
// to support wildcard port numbers as described by https://tools.ietf.org/html/rfc8252:
// The authorization server MUST allow any port to be specified at the
// time of the request for loopback IP redirect URIs, to accommodate
// clients that obtain an available ephemeral port from the operating
// system at the time of the request.
func WithListenPort(port uint16) Option {
return func(h *handlerState) error {
h.listenAddr = fmt.Sprintf("localhost:%d", port)
return nil
}
}
// WithScopes sets the OAuth2 scopes to request during login. If not specified, it defaults to
// "offline_access openid email profile".
func WithScopes(scopes []string) Option {
return func(h *handlerState) error {
h.scopes = scopes
return nil
}
}
// WithBrowserOpen overrides the default "open browser" functionality with a custom callback. If not specified,
// an implementation using https://github.com/pkg/browser will be used by default.
func WithBrowserOpen(openURL func(url string) error) Option {
return func(h *handlerState) error {
h.openURL = openURL
return nil
}
}
// Run an OAuth2/OIDC authorization code login using a localhost listener.
func Run(issuer string, clientID string, opts ...Option) (*Token, error) {
h := handlerState{
issuer: issuer,
clientID: clientID,
listenAddr: "localhost:0",
scopes: []string{"offline_access", "openid", "email", "profile"},
callbackPath: "/callback",
ctx: context.Background(),
callbacks: make(chan callbackResult),
// Default implementations of external dependencies (to be mocked in tests).
generateState: state.Generate,
generateNonce: nonce.Generate,
generatePKCE: pkce.Generate,
openURL: browser.OpenURL,
}
for _, opt := range opts {
if err := opt(&h); err != nil {
return nil, err
}
}
// Always set a long, but non-infinite timeout for this operation.
ctx, cancel := context.WithTimeout(h.ctx, 10*time.Minute)
defer cancel()
h.ctx = ctx
// Initialize login parameters.
var err error
h.state, err = h.generateState()
if err != nil {
return nil, err
}
h.nonce, err = h.generateNonce()
if err != nil {
return nil, err
}
h.pkce, err = h.generatePKCE()
if err != nil {
return nil, err
}
// Perform OIDC discovery.
provider, err := oidc.NewProvider(h.ctx, h.issuer)
if err != nil {
return nil, fmt.Errorf("could not perform OIDC discovery for %q: %w", h.issuer, err)
}
h.idTokenVerifier = provider.Verifier(&oidc.Config{ClientID: h.clientID})
// Open a TCP listener.
listener, err := net.Listen("tcp", h.listenAddr)
if err != nil {
return nil, fmt.Errorf("could not open callback listener: %w", err)
}
// Build an OAuth2 configuration based on the OIDC discovery data and our callback endpoint.
h.oauth2Config = &oauth2.Config{
ClientID: h.clientID,
Endpoint: provider.Endpoint(),
RedirectURL: (&url.URL{
Scheme: "http",
Host: listener.Addr().String(),
Path: h.callbackPath,
}).String(),
Scopes: h.scopes,
}
// Start a callback server in a background goroutine.
mux := http.NewServeMux()
mux.Handle(h.callbackPath, httperr.HandlerFunc(h.handleAuthCodeCallback))
srv := http.Server{
Handler: securityheader.Wrap(mux),
BaseContext: func(_ net.Listener) context.Context { return h.ctx },
}
go func() { _ = srv.Serve(listener) }()
defer func() {
// Gracefully shut down the server, allowing up to 5 seconds for
// clients to receive any in-flight responses.
shutdownCtx, cancel := context.WithTimeout(h.ctx, 1*time.Second)
_ = srv.Shutdown(shutdownCtx)
cancel()
}()
// Open the authorize URL in the users browser.
authorizeURL := h.oauth2Config.AuthCodeURL(
h.state.String(),
oauth2.AccessTypeOffline,
h.nonce.Param(),
h.pkce.Challenge(),
h.pkce.Method(),
)
if err := h.openURL(authorizeURL); err != nil {
return nil, fmt.Errorf("could not open browser: %w", err)
}
// Wait for either the callback or a timeout.
select {
case <-h.ctx.Done():
return nil, fmt.Errorf("timed out waiting for token callback: %w", h.ctx.Err())
case callback := <-h.callbacks:
if callback.err != nil {
return nil, fmt.Errorf("error handling callback: %w", callback.err)
}
return callback.token, nil
}
}
func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Request) (err error) {
// If we return an error, also report it back over the channel to the main CLI thread.
defer func() {
if err != nil {
h.callbacks <- callbackResult{err: err}
}
}()
// Return HTTP 405 for anything that's not a GET.
if r.Method != http.MethodGet {
return httperr.Newf(http.StatusMethodNotAllowed, "wanted GET")
}
// Validate OAuth2 state and fail if it's incorrect (to block CSRF).
params := r.URL.Query()
if err := h.state.Validate(params.Get("state")); err != nil {
return httperr.New(http.StatusForbidden, "missing or invalid state parameter")
}
// Check for error response parameters.
if errorParam := params.Get("error"); errorParam != "" {
return httperr.Newf(http.StatusBadRequest, "login failed with code %q", errorParam)
}
// Exchange the authorization code for access, ID, and refresh tokens.
oauth2Tok, err := h.oauth2Config.Exchange(r.Context(), params.Get("code"), h.pkce.Verifier())
if err != nil {
return httperr.Wrap(http.StatusBadRequest, "could not complete code exchange", err)
}
// Perform required validations on the returned ID token.
idTok, hasIDTok := oauth2Tok.Extra("id_token").(string)
if !hasIDTok {
return httperr.New(http.StatusBadRequest, "received response missing ID token")
}
validated, err := h.idTokenVerifier.Verify(r.Context(), idTok)
if err != nil {
return httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
if validated.AccessTokenHash != "" {
if err := validated.VerifyAccessToken(oauth2Tok.AccessToken); err != nil {
return httperr.Wrap(http.StatusBadRequest, "received invalid ID token", err)
}
}
if err := h.nonce.Validate(validated); err != nil {
return httperr.Wrap(http.StatusBadRequest, "received ID token with invalid nonce", err)
}
h.callbacks <- callbackResult{token: &Token{
Token: oauth2Tok,
IDToken: idTok,
IDTokenExpiry: validated.Expiry,
}}
_, _ = w.Write([]byte("you have been logged in and may now close this tab"))
return nil
}

View File

@ -0,0 +1,420 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package login
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/coreos/go-oidc"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/mocks/mockkeyset"
"go.pinniped.dev/internal/oidcclient/nonce"
"go.pinniped.dev/internal/oidcclient/pkce"
"go.pinniped.dev/internal/oidcclient/state"
)
func TestRun(t *testing.T) {
time1 := time.Date(3020, 10, 12, 13, 14, 15, 16, time.UTC)
testToken := Token{
Token: &oauth2.Token{
AccessToken: "test-access-token",
RefreshToken: "test-refresh-token",
Expiry: time1.Add(1 * time.Minute),
},
IDToken: "test-id-token",
IDTokenExpiry: time1.Add(2 * time.Minute),
}
_ = testToken
// Start a test server that returns 500 errors
errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "some discovery error", http.StatusInternalServerError)
}))
t.Cleanup(errorServer.Close)
// Start a test server that returns a real keyset
providerMux := http.NewServeMux()
successServer := httptest.NewServer(providerMux)
t.Cleanup(successServer.Close)
providerMux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("content-type", "application/json")
type providerJSON struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
}
_ = json.NewEncoder(w).Encode(&providerJSON{
Issuer: successServer.URL,
AuthURL: successServer.URL + "/authorize",
TokenURL: successServer.URL + "/token",
JWKSURL: successServer.URL + "/keys",
})
})
tests := []struct {
name string
opt func(t *testing.T) Option
issuer string
clientID string
wantErr string
wantToken *Token
}{
{
name: "option error",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
return fmt.Errorf("some option error")
}
},
wantErr: "some option error",
},
{
name: "error generating state",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "", fmt.Errorf("some error generating state") }
return nil
}
},
wantErr: "some error generating state",
},
{
name: "error generating nonce",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generateNonce = func() (nonce.Nonce, error) { return "", fmt.Errorf("some error generating nonce") }
return nil
}
},
wantErr: "some error generating nonce",
},
{
name: "error generating PKCE",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generatePKCE = func() (pkce.Code, error) { return "", fmt.Errorf("some error generating PKCE") }
return nil
}
},
wantErr: "some error generating PKCE",
},
{
name: "discovery failure",
opt: func(t *testing.T) Option {
return func(h *handlerState) error { return nil }
},
issuer: errorServer.URL,
wantErr: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: some discovery error\n", errorServer.URL),
},
{
name: "listen failure",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.listenAddr = "invalid-listen-address"
return nil
}
},
issuer: successServer.URL,
wantErr: "could not open callback listener: listen tcp: address invalid-listen-address: missing port in address",
},
{
name: "browser open failure",
opt: func(t *testing.T) Option {
return WithBrowserOpen(func(url string) error {
return fmt.Errorf("some browser open error")
})
},
issuer: successServer.URL,
wantErr: "could not open browser: some browser open error",
},
{
name: "timeout waiting for callback",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
ctx, cancel := context.WithCancel(h.ctx)
h.ctx = ctx
h.openURL = func(_ string) error {
cancel()
return nil
}
return nil
}
},
issuer: successServer.URL,
wantErr: "timed out waiting for token callback: context canceled",
},
{
name: "callback returns error",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.openURL = func(_ string) error {
go func() {
h.callbacks <- callbackResult{err: fmt.Errorf("some callback error")}
}()
return nil
}
return nil
}
},
issuer: successServer.URL,
wantErr: "error handling callback: some callback error",
},
{
name: "callback returns success",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
h.generateState = func() (state.State, error) { return "test-state", nil }
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
h.openURL = func(actualURL string) error {
parsedActualURL, err := url.Parse(actualURL)
require.NoError(t, err)
actualParams := parsedActualURL.Query()
require.Contains(t, actualParams.Get("redirect_uri"), "http://127.0.0.1:")
actualParams.Del("redirect_uri")
require.Equal(t, url.Values{
// This is the PKCE challenge which is calculated as base64(sha256("test-pkce")). For example:
// $ echo -n test-pkce | shasum -a 256 | cut -d" " -f1 | xxd -r -p | base64 | cut -d"=" -f1
// VVaezYqum7reIhoavCHD1n2d+piN3r/mywoYj7fCR7g
"code_challenge": []string{"VVaezYqum7reIhoavCHD1n2d-piN3r_mywoYj7fCR7g"},
"code_challenge_method": []string{"S256"},
"response_type": []string{"code"},
"scope": []string{"test-scope"},
"nonce": []string{"test-nonce"},
"state": []string{"test-state"},
"access_type": []string{"offline"},
"client_id": []string{"test-client-id"},
}, actualParams)
parsedActualURL.RawQuery = ""
require.Equal(t, successServer.URL+"/authorize", parsedActualURL.String())
go func() {
h.callbacks <- callbackResult{token: &testToken}
}()
return nil
}
return nil
}
},
issuer: successServer.URL,
wantToken: &testToken,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
tok, err := Run(tt.issuer, tt.clientID,
WithContext(context.Background()),
WithListenPort(0),
WithScopes([]string{"test-scope"}),
tt.opt(t),
)
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
require.Nil(t, tok)
return
}
require.Equal(t, tt.wantToken, tok)
})
}
}
func TestHandleAuthCodeCallback(t *testing.T) {
tests := []struct {
name string
method string
query string
returnIDTok string
wantErr string
wantHTTPStatus int
}{
{
name: "wrong method",
method: "POST",
query: "",
wantErr: "wanted GET",
wantHTTPStatus: http.StatusMethodNotAllowed,
},
{
name: "invalid state",
query: "state=invalid",
wantErr: "missing or invalid state parameter",
wantHTTPStatus: http.StatusForbidden,
},
{
name: "error code from provider",
query: "state=test-state&error=some_error",
wantErr: `login failed with code "some_error"`,
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid code",
query: "state=test-state&code=invalid",
wantErr: "could not complete code exchange: oauth2: cannot fetch token: 403 Forbidden\nResponse: invalid authorization code\n",
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "missing ID token",
query: "state=test-state&code=valid",
returnIDTok: "",
wantErr: "received response missing ID token",
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid ID token",
query: "state=test-state&code=valid",
returnIDTok: "invalid-jwt",
wantErr: "received invalid ID token: oidc: malformed jwt: square/go-jose: compact JWS format must have three parts",
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid access token hash",
query: "state=test-state&code=valid",
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"at_hash": "invalid-at-hash"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdF9oYXNoIjoiaW52YWxpZC1hdC1oYXNoIiwiYXVkIjoidGVzdC1jbGllbnQtaWQiLCJpYXQiOjE2MDIyODM3OTEsImp0aSI6InRlc3QtanRpIiwibmJmIjoxNjAyMjgzNzkxLCJzdWIiOiJ0ZXN0LXVzZXIifQ.jryXr4jiwcf79wBLaHpjdclEYHoUFGhvTu95QyA6Hnk9NQ0x1vsWYurtj7a8uKydNPryC_HNZi9QTAE_tRIJjycseog3695-5y4B4EZlqL-a94rdOtffuF2O_lnPbKvoja9EKNrp0kLBCftFRHhLAEwuP0N9E5padZwPpIGK0yE_JqljnYgCySvzsQu7tasR38yaULny13h3mtp2WRHPG5DrLyuBuF8Z01hSgRi5hGcVpgzTwBgV5-eMaSUCUo-ZDkqUsLQI6dVlaikCSKYZRb53HeexH0tB_R9PJJHY7mIr-rS76kkQEx9pLuVnheIH9Oc6zbdYWg-zWMijopA8Pg",
wantErr: "received invalid ID token: access token hash does not match value in ID token",
wantHTTPStatus: http.StatusBadRequest,
},
{
name: "invalid nonce",
query: "state=test-state&code=valid",
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "invalid-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjI4Mzc0MSwianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDIyODM3NDEsIm5vbmNlIjoiaW52YWxpZC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.PRpq-7j5djaIAkraL-8t8ad9Xm4hM8RW67gyD1VIe0BecWeBFxsTuh3SZVKM9zmcwTgjudsyn8kQOwipDa49IN4PV8FcJA_uUJZi2wiqGJUSTG2K5I89doV_7e0RM1ZYIDDW1G2heKJNW7MbKkX7iEPr7u4MyEzswcPcupbyDA-CQFeL95vgwawoqa6yO94ympTbozqiNfj6Xyw_nHtThQnstjWsJZ9s2mUgppZezZv4HZYTQ7c3e_bzwhWgCzh2CSDJn9_Ra_n_4GcVkpHbsHTP35dFsnf0vactPx6CAu6A1-Apk-BruCktpZ3B4Ercf1UnUOHdGqzQKJtqvB03xQ",
wantHTTPStatus: http.StatusBadRequest,
wantErr: `received ID token with invalid nonce: invalid nonce (expected "test-nonce", got "invalid-nonce")`,
},
{
name: "valid",
query: "state=test-state&code=valid",
// Test JWT generated with https://smallstep.com/docs/cli/crypto/jwt/:
// step crypto keypair key.pub key.priv --kty RSA --no-password --insecure --force && echo '{"nonce": "test-nonce"}' | step crypto jwt sign --key key.priv --aud test-client-id --sub test-user --subtle --kid="test-kid" --jti="test-jti"
returnIDTok: "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImlhdCI6MTYwMjUzMTU2NywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDI1MzE1NjcsIm5vbmNlIjoidGVzdC1ub25jZSIsInN1YiI6InRlc3QtdXNlciJ9.LbOA31iwJZBM4ayY5Oud-HArLXbmtAIhZv_LazDqbzA2Iw87RxoBemfiPUJeAesdnO1LKSjBwbltZwtjvbLWHp1R5tqrSMr_hl2OyZv1cpEX-9QaTcQILJ5qR00riRLz34ZCQFyF-FfQpP1r4dNqFrxHuiBwKuPE7zogc83ZYJgAQM5Fao9rIRY9JStL_3pURa9JnnSHFlkLvFYv3TKEUyvnW4pWvYZcsGI7mys43vuSjpG7ZSrW3vCxovuIpXYqAhamZL_XexWUsXvi3ej9HNlhnhOFhN4fuPSc0PWDWaN0CLWmoo8gvOdQWo5A4GD4bNGBzjYOd-pYqsDfseRt1Q",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.NoError(t, r.ParseForm())
require.Equal(t, "test-client-id", r.Form.Get("client_id"))
require.Equal(t, "test-pkce", r.Form.Get("code_verifier"))
require.Equal(t, "authorization_code", r.Form.Get("grant_type"))
require.NotEmpty(t, r.Form.Get("code"))
if r.Form.Get("code") != "valid" {
http.Error(w, "invalid authorization code", http.StatusForbidden)
return
}
var response struct {
oauth2.Token
IDToken string `json:"id_token,omitempty"`
}
response.AccessToken = "test-access-token"
response.Expiry = time.Now().Add(time.Hour)
response.IDToken = tt.returnIDTok
w.Header().Set("content-type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(&response))
}))
t.Cleanup(tokenServer.Close)
h := &handlerState{
callbacks: make(chan callbackResult, 1),
state: state.State("test-state"),
pkce: pkce.Code("test-pkce"),
nonce: nonce.Nonce("test-nonce"),
oauth2Config: &oauth2.Config{
ClientID: "test-client-id",
RedirectURL: "http://localhost:12345/callback",
Endpoint: oauth2.Endpoint{
TokenURL: tokenServer.URL,
AuthStyle: oauth2.AuthStyleInParams,
},
},
idTokenVerifier: mockVerifier(),
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
resp := httptest.NewRecorder()
req, err := http.NewRequestWithContext(ctx, "GET", "/test-callback", nil)
require.NoError(t, err)
req.URL.RawQuery = tt.query
if tt.method != "" {
req.Method = tt.method
}
err = h.handleAuthCodeCallback(resp, req)
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
if tt.wantHTTPStatus != 0 {
rec := httptest.NewRecorder()
err.(httperr.Responder).Respond(rec)
require.Equal(t, tt.wantHTTPStatus, rec.Code)
}
} else {
require.NoError(t, err)
}
select {
case <-time.After(1 * time.Second):
require.Fail(t, "timed out waiting to receive from callbacks channel")
case result := <-h.callbacks:
if tt.wantErr != "" {
require.EqualError(t, result.err, tt.wantErr)
return
}
require.NoError(t, result.err)
require.NotNil(t, result.token)
require.Equal(t, result.token.IDToken, tt.returnIDTok)
}
})
}
}
// mockVerifier returns an *oidc.IDTokenVerifier that validates any correctly serialized JWT without doing much else.
func mockVerifier() *oidc.IDTokenVerifier {
mockKeySet := mockkeyset.NewMockKeySet(gomock.NewController(nil))
mockKeySet.EXPECT().VerifySignature(gomock.Any(), gomock.Any()).
AnyTimes().
DoAndReturn(func(ctx context.Context, jwt string) ([]byte, error) {
jws, err := jose.ParseSigned(jwt)
if err != nil {
return nil, err
}
return jws.UnsafePayloadWithoutVerification(), nil
})
return oidc.NewVerifier("", mockKeySet, &oidc.Config{
SkipIssuerCheck: true,
SkipExpiryCheck: true,
SkipClientIDCheck: true,
})
}

View File

@ -0,0 +1,58 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package nonce implements
package nonce
import (
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"fmt"
"io"
"github.com/coreos/go-oidc"
"golang.org/x/oauth2"
)
// Generate generates a new random OIDC nonce parameter of an appropriate size.
func Generate() (Nonce, error) { return generate(rand.Reader) }
func generate(rand io.Reader) (Nonce, error) {
var buf [16]byte
if _, err := io.ReadFull(rand, buf[:]); err != nil {
return "", fmt.Errorf("could not generate random nonce: %w", err)
}
return Nonce(hex.EncodeToString(buf[:])), nil
}
// Nonce implements some utilities for working with OIDC nonce parameters.
type Nonce string
// String returns the string encoding of this state value.
func (n *Nonce) String() string {
return string(*n)
}
// Param returns the OAuth2 auth code parameter for sending the nonce during the authorization request.
func (n *Nonce) Param() oauth2.AuthCodeOption {
return oidc.Nonce(string(*n))
}
// Validate the returned ID token). Returns true iff the nonce matches or the returned JWT does not have a nonce.
func (n *Nonce) Validate(token *oidc.IDToken) error {
if subtle.ConstantTimeCompare([]byte(token.Nonce), []byte(*n)) != 1 {
return InvalidNonceError{Expected: *n, Got: Nonce(token.Nonce)}
}
return nil
}
// InvalidNonceError is returned by Validate when the observed nonce is invalid.
type InvalidNonceError struct {
Expected Nonce
Got Nonce
}
func (e InvalidNonceError) Error() string {
return fmt.Sprintf("invalid nonce (expected %q, got %q)", e.Expected.String(), e.Got.String())
}

View File

@ -0,0 +1,40 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package nonce
import (
"bytes"
"errors"
"net/url"
"testing"
"github.com/coreos/go-oidc"
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
)
func TestNonce(t *testing.T) {
n, err := Generate()
require.NoError(t, err)
require.Len(t, n, 32)
require.Len(t, n.String(), 32)
cfg := oauth2.Config{}
authCodeURL, err := url.Parse(cfg.AuthCodeURL("", n.Param()))
require.NoError(t, err)
require.Equal(t, n.String(), authCodeURL.Query().Get("nonce"))
require.Error(t, n.Validate(&oidc.IDToken{}))
require.NoError(t, n.Validate(&oidc.IDToken{Nonce: string(n)}))
err = n.Validate(&oidc.IDToken{Nonce: string(n) + "x"})
require.Error(t, err)
require.True(t, errors.As(err, &InvalidNonceError{}))
require.Contains(t, err.Error(), string(n)+"x")
var empty bytes.Buffer
n, err = generate(&empty)
require.EqualError(t, err, "could not generate random nonce: EOF")
require.Empty(t, n)
}

View File

@ -8,9 +8,9 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"fmt"
"io" "io"
"github.com/pkg/errors"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@ -18,9 +18,14 @@ import (
func Generate() (Code, error) { return generate(rand.Reader) } func Generate() (Code, error) { return generate(rand.Reader) }
func generate(rand io.Reader) (Code, error) { func generate(rand io.Reader) (Code, error) {
// From https://tools.ietf.org/html/rfc7636#section-4.1:
// code_verifier = high-entropy cryptographic random STRING using the
// unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"
// from Section 2.3 of [RFC3986], with a minimum length of 43 characters
// and a maximum length of 128 characters.
var buf [32]byte var buf [32]byte
if _, err := io.ReadFull(rand, buf[:]); err != nil { if _, err := io.ReadFull(rand, buf[:]); err != nil {
return "", errors.WithMessage(err, "could not generate PKCE code") return "", fmt.Errorf("could not generate PKCE code: %w", err)
} }
return Code(hex.EncodeToString(buf[:])), nil return Code(hex.EncodeToString(buf[:])), nil
} }

View File

@ -0,0 +1,56 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package state
import (
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"fmt"
"io"
)
// Generate generates a new random state parameter of an appropriate size.
func Generate() (State, error) { return generate(rand.Reader) }
func generate(rand io.Reader) (State, error) {
// From https://tools.ietf.org/html/rfc6749#section-10.12:
// The binding value used for CSRF
// protection MUST contain a non-guessable value (as described in
// Section 10.10), and the user-agent's authenticated state (e.g.,
// session cookie, HTML5 local storage) MUST be kept in a location
// accessible only to the client and the user-agent (i.e., protected by
// same-origin policy).
var buf [16]byte
if _, err := io.ReadFull(rand, buf[:]); err != nil {
return "", fmt.Errorf("could not generate random state: %w", err)
}
return State(hex.EncodeToString(buf[:])), nil
}
// State implements some utilities for working with OAuth2 state parameters.
type State string
// String returns the string encoding of this state value.
func (s *State) String() string {
return string(*s)
}
// Validate the returned state (from a callback parameter).
func (s *State) Validate(returnedState string) error {
if subtle.ConstantTimeCompare([]byte(returnedState), []byte(*s)) != 1 {
return InvalidStateError{Expected: *s, Got: State(returnedState)}
}
return nil
}
// InvalidStateError is returned by Validate when the returned state is invalid.
type InvalidStateError struct {
Expected State
Got State
}
func (e InvalidStateError) Error() string {
return fmt.Sprintf("invalid state (expected %q, got %q)", e.Expected.String(), e.Got.String())
}

View File

@ -5,6 +5,7 @@ package state
import ( import (
"bytes" "bytes"
"errors"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -15,8 +16,11 @@ func TestState(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Len(t, s, 32) require.Len(t, s, 32)
require.Len(t, s.String(), 32) require.Len(t, s.String(), 32)
require.True(t, s.Valid(string(s))) require.NoError(t, s.Validate(string(s)))
require.False(t, s.Valid(string(s)+"x")) err = s.Validate(string(s) + "x")
require.Error(t, err)
require.True(t, errors.As(err, &InvalidStateError{}))
require.Contains(t, err.Error(), string(s)+"x")
var empty bytes.Buffer var empty bytes.Buffer
s, err = generate(&empty) s, err = generate(&empty)