diff --git a/cmd/pinniped/cmd/alpha.go b/cmd/pinniped/cmd/alpha.go new file mode 100644 index 00000000..db27150f --- /dev/null +++ b/cmd/pinniped/cmd/alpha.go @@ -0,0 +1,22 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "github.com/spf13/cobra" +) + +//nolint: gochecknoglobals +var alphaCmd = &cobra.Command{ + Use: "alpha", + Short: "alpha", + Long: "alpha subcommands (syntax or flags are still subject to change)", + SilenceUsage: true, // do not print usage message when commands fail + Hidden: true, +} + +//nolint: gochecknoinits +func init() { + rootCmd.AddCommand(alphaCmd) +} diff --git a/cmd/pinniped/cmd/cobra_util.go b/cmd/pinniped/cmd/cobra_util.go new file mode 100644 index 00000000..9b153a72 --- /dev/null +++ b/cmd/pinniped/cmd/cobra_util.go @@ -0,0 +1,15 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import "github.com/spf13/cobra" + +// mustMarkRequired marks the given flags as required on the provided cobra.Command. If any of the names are wrong, it panics. +func mustMarkRequired(cmd *cobra.Command, flags ...string) { + for _, flag := range flags { + if err := cmd.MarkFlagRequired(flag); err != nil { + panic(err) + } + } +} diff --git a/cmd/pinniped/cmd/cobra_util_test.go b/cmd/pinniped/cmd/cobra_util_test.go new file mode 100644 index 00000000..b44e8550 --- /dev/null +++ b/cmd/pinniped/cmd/cobra_util_test.go @@ -0,0 +1,21 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/require" +) + +func TestMustMarkRequired(t *testing.T) { + require.NotPanics(t, func() { mustMarkRequired(&cobra.Command{}) }) + require.NotPanics(t, func() { + cmd := &cobra.Command{} + cmd.Flags().String("known-flag", "", "") + mustMarkRequired(cmd, "known-flag") + }) + require.Panics(t, func() { mustMarkRequired(&cobra.Command{}, "unknown-flag") }) +} diff --git a/cmd/pinniped/cmd/get_kubeconfig.go b/cmd/pinniped/cmd/get_kubeconfig.go index 0660c556..8ed99b0d 100644 --- a/cmd/pinniped/cmd/get_kubeconfig.go +++ b/cmd/pinniped/cmd/get_kubeconfig.go @@ -85,15 +85,12 @@ func (c *getKubeConfigCommand) Command() *cobra.Command { `), } cmd.Flags().StringVar(&c.flags.token, "token", "", "Credential to include in the resulting kubeconfig output (Required)") - err := cmd.MarkFlagRequired("token") - if err != nil { - panic(err) - } cmd.Flags().StringVar(&c.flags.kubeconfig, "kubeconfig", c.flags.kubeconfig, "Path to the kubeconfig file") cmd.Flags().StringVar(&c.flags.contextOverride, "kubeconfig-context", c.flags.contextOverride, "Kubeconfig context override") cmd.Flags().StringVar(&c.flags.namespace, "pinniped-namespace", c.flags.namespace, "Namespace in which Pinniped was installed") cmd.Flags().StringVar(&c.flags.idpType, "idp-type", c.flags.idpType, "Identity provider type (e.g., 'webhook')") cmd.Flags().StringVar(&c.flags.idpName, "idp-name", c.flags.idpType, "Identity provider name") + mustMarkRequired(cmd, "token") return cmd } diff --git a/cmd/pinniped/cmd/login.go b/cmd/pinniped/cmd/login.go new file mode 100644 index 00000000..2c0ad082 --- /dev/null +++ b/cmd/pinniped/cmd/login.go @@ -0,0 +1,21 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "github.com/spf13/cobra" +) + +//nolint: gochecknoglobals +var loginCmd = &cobra.Command{ + Use: "login", + Short: "login", + Long: "Login to a Pinniped server", + SilenceUsage: true, // do not print usage message when commands fail +} + +//nolint: gochecknoinits +func init() { + alphaCmd.AddCommand(loginCmd) +} diff --git a/cmd/pinniped/cmd/login_oidc.go b/cmd/pinniped/cmd/login_oidc.go new file mode 100644 index 00000000..e9287d43 --- /dev/null +++ b/cmd/pinniped/cmd/login_oidc.go @@ -0,0 +1,126 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "fmt" + + "github.com/coreos/go-oidc" + "github.com/pkg/browser" + "github.com/spf13/cobra" + "golang.org/x/oauth2" + + "go.pinniped.dev/internal/here" + "go.pinniped.dev/internal/oidc/pkce" + "go.pinniped.dev/internal/oidc/state" +) + +//nolint: gochecknoinits +func init() { + loginCmd.AddCommand((&oidcLoginParams{ + generateState: state.Generate, + generatePKCE: pkce.Generate, + openURL: browser.OpenURL, + }).cmd()) +} + +type oidcLoginParams struct { + // These parameters capture CLI flags. + issuer string + clientID string + listenPort uint16 + scopes []string + skipBrowser bool + usePKCE bool + debugAuthCode bool + + // These parameters capture dependencies that we want to mock during testing. + generateState func() (state.State, error) + generatePKCE func() (pkce.Code, error) + openURL func(string) error +} + +func (o *oidcLoginParams) cmd() *cobra.Command { + cmd := cobra.Command{ + Args: cobra.NoArgs, + Use: "oidc --issuer ISSUER --client-id CLIENT_ID", + Short: "Login using an OpenID Connect provider", + RunE: o.runE, + SilenceUsage: true, + } + cmd.Flags().StringVar(&o.issuer, "issuer", "", "OpenID Connect issuer URL.") + cmd.Flags().StringVar(&o.clientID, "client-id", "", "OpenID Connect client ID.") + cmd.Flags().Uint16Var(&o.listenPort, "listen-port", 48095, "TCP port for localhost listener (authorization code flow only).") + cmd.Flags().StringSliceVar(&o.scopes, "scopes", []string{"offline_access", "openid", "email", "profile"}, "OIDC scopes to request during login.") + cmd.Flags().BoolVar(&o.skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL).") + cmd.Flags().BoolVar(&o.usePKCE, "use-pkce", true, "Use Proof Key for Code Exchange (RFC 7636) during login.") + mustMarkRequired(&cmd, "issuer", "client-id") + + // TODO: temporary + cmd.Flags().BoolVar(&o.debugAuthCode, "debug-auth-code-exchange", true, "Debug the authorization code exchange (temporary).") + _ = cmd.Flags().MarkHidden("debug-auth-code-exchange") + + return &cmd +} + +func (o *oidcLoginParams) runE(cmd *cobra.Command, args []string) error { + metadata, err := oidc.NewProvider(cmd.Context(), o.issuer) + if err != nil { + return fmt.Errorf("could not perform OIDC discovery for %q: %w", o.issuer, err) + } + + cfg := oauth2.Config{ + ClientID: o.clientID, + Endpoint: metadata.Endpoint(), + RedirectURL: fmt.Sprintf("http://localhost:%d/callback", o.listenPort), + Scopes: o.scopes, + } + + authCodeOptions := []oauth2.AuthCodeOption{oauth2.AccessTypeOffline} + + stateParam, err := o.generateState() + if err != nil { + return fmt.Errorf("could not generate OIDC state parameter: %w", err) + } + + var pkceCode pkce.Code + if o.usePKCE { + pkceCode, err = o.generatePKCE() + if err != nil { + return fmt.Errorf("could not generate OIDC PKCE parameter: %w", err) + } + authCodeOptions = append(authCodeOptions, pkceCode.Challenge(), pkceCode.Method()) + } + + // If --skip-browser was passed, override the default browser open function with a Printf() call. + openURL := o.openURL + if o.skipBrowser { + openURL = func(s string) error { + cmd.PrintErr("Please log in: ", s, "\n") + return nil + } + } + + authorizeURL := cfg.AuthCodeURL(stateParam.String(), authCodeOptions...) + if err := openURL(authorizeURL); err != nil { + return fmt.Errorf("could not open browser (run again with --skip-browser?): %w", err) + } + + // TODO: this temporary so we can complete the auth code exchange manually + + if o.debugAuthCode { + cmd.PrintErr(here.Docf(` + DEBUG INFO: + Token URL: %s + State: %s + PKCE: %s + `, + cfg.Endpoint.TokenURL, + stateParam, + pkceCode.Verifier(), + )) + } + + return nil +} diff --git a/cmd/pinniped/cmd/login_oidc_test.go b/cmd/pinniped/cmd/login_oidc_test.go new file mode 100644 index 00000000..9258cc30 --- /dev/null +++ b/cmd/pinniped/cmd/login_oidc_test.go @@ -0,0 +1,220 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "regexp" + "strings" + "testing" + + "github.com/spf13/cobra" + "github.com/stretchr/testify/require" + + "go.pinniped.dev/internal/here" + "go.pinniped.dev/internal/oidc/pkce" + "go.pinniped.dev/internal/oidc/state" +) + +func TestLoginOIDCCommand(t *testing.T) { + t.Parallel() + tests := []struct { + name string + args []string + wantError bool + wantStdout string + wantStderr string + }{ + { + name: "help flag passed", + args: []string{"--help"}, + wantStdout: here.Doc(` + Login using an OpenID Connect provider + + Usage: + oidc --issuer ISSUER --client-id CLIENT_ID [flags] + + Flags: + --client-id string OpenID Connect client ID. + -h, --help help for oidc + --issuer string OpenID Connect issuer URL. + --listen-port uint16 TCP port for localhost listener (authorization code flow only). (default 48095) + --scopes strings OIDC scopes to request during login. (default [offline_access,openid,email,profile]) + --skip-browser Skip opening the browser (just print the URL). + --use-pkce Use Proof Key for Code Exchange (RFC 7636) during login. (default true) + `), + }, + { + name: "missing required flags", + args: []string{}, + wantError: true, + wantStdout: here.Doc(` + Error: required flag(s) "client-id", "issuer" not set + `), + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + cmd := (&oidcLoginParams{}).cmd() + require.NotNil(t, cmd) + + var stdout, stderr bytes.Buffer + cmd.SetOut(&stdout) + cmd.SetErr(&stderr) + cmd.SetArgs(tt.args) + err := cmd.Execute() + if tt.wantError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.wantStdout, stdout.String(), "unexpected stdout") + require.Equal(t, tt.wantStderr, stderr.String(), "unexpected stderr") + }) + } +} + +func TestOIDCLoginRunE(t *testing.T) { + t.Parallel() + + // Start a server that returns 500 errors. + brokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + })) + t.Cleanup(brokenServer.Close) + + // Start a server that returns successfully. + var validResponse string + validServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(validResponse)) + })) + t.Cleanup(validServer.Close) + validResponse = strings.ReplaceAll(here.Docf(` + { + "issuer": "${ISSUER}", + "authorization_endpoint": "${ISSUER}/auth", + "token_endpoint": "${ISSUER}/token", + "jwks_uri": "${ISSUER}/keys", + "userinfo_endpoint": "${ISSUER}/userinfo", + "grant_types_supported": ["authorization_code","refresh_token"], + "response_types_supported": ["code"], + "id_token_signing_alg_values_supported": ["RS256"], + "scopes_supported": ["openid","email","groups","profile","offline_access"], + "token_endpoint_auth_methods_supported": ["client_secret_basic"], + "claims_supported": ["aud","email","email_verified","exp","iat","iss","locale","name","sub"] + } + `), "${ISSUER}", validServer.URL) + validServerURL, err := url.Parse(validServer.URL) + require.NoError(t, err) + + tests := []struct { + name string + params oidcLoginParams + wantError string + wantStdout string + wantStderr string + wantStderrAuthURL func(*testing.T, *url.URL) + }{ + { + name: "broken discovery", + params: oidcLoginParams{ + issuer: brokenServer.URL, + }, + wantError: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: Internal Server Error\n", brokenServer.URL), + }, + { + name: "broken state generation", + params: oidcLoginParams{ + issuer: validServer.URL, + generateState: func() (state.State, error) { return "", fmt.Errorf("some error generating a state value") }, + }, + wantError: "could not generate OIDC state parameter: some error generating a state value", + }, + { + name: "broken PKCE generation", + params: oidcLoginParams{ + issuer: validServer.URL, + generateState: func() (state.State, error) { return "test-state", nil }, + usePKCE: true, + generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some error generating a PKCE code") }, + }, + wantError: "could not generate OIDC PKCE parameter: some error generating a PKCE code", + }, + { + name: "broken browser open", + params: oidcLoginParams{ + issuer: validServer.URL, + generateState: func() (state.State, error) { return "test-state", nil }, + usePKCE: true, + generatePKCE: func() (pkce.Code, error) { return "test-pkce", nil }, + openURL: func(_ string) error { return fmt.Errorf("some browser open error") }, + }, + wantError: "could not open browser (run again with --skip-browser?): some browser open error", + }, + { + name: "success without PKCE", + params: oidcLoginParams{ + issuer: validServer.URL, + clientID: "test-client-id", + generateState: func() (state.State, error) { return "test-state", nil }, + usePKCE: false, + listenPort: 12345, + skipBrowser: true, + }, + wantStderrAuthURL: func(t *testing.T, actual *url.URL) { + require.Equal(t, validServerURL.Host, actual.Host) + require.Equal(t, "/auth", actual.Path) + require.Equal(t, "", actual.Fragment) + require.Equal(t, url.Values{ + "access_type": []string{"offline"}, + "client_id": []string{"test-client-id"}, + "redirect_uri": []string{"http://localhost:12345/callback"}, + "response_type": []string{"code"}, + "state": []string{"test-state"}, + }, actual.Query()) + }, + wantStderr: "Please log in: \n", + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + var stdout, stderr bytes.Buffer + cmd := cobra.Command{RunE: tt.params.runE, SilenceUsage: true, SilenceErrors: true} + cmd.SetOut(&stdout) + cmd.SetErr(&stderr) + err := cmd.Execute() + if tt.wantError != "" { + require.EqualError(t, err, tt.wantError) + } else { + require.NoError(t, err) + } + + if tt.wantStderrAuthURL != nil { + var urls []string + redacted := regexp.MustCompile(`http://\S+`).ReplaceAllStringFunc(stderr.String(), func(url string) string { + urls = append(urls, url) + return "" + }) + require.Lenf(t, urls, 1, "expected to find authorization URL in stderr:\n%s", stderr.String()) + authURL, err := url.Parse(urls[0]) + require.NoError(t, err, "invalid authorization URL") + tt.wantStderrAuthURL(t, authURL) + + // Replace the stderr buffer with the redacted version. + stderr.Reset() + stderr.WriteString(redacted) + } + + require.Equal(t, tt.wantStdout, stdout.String(), "unexpected stdout") + require.Equal(t, tt.wantStderr, stderr.String(), "unexpected stderr") + }) + } +} diff --git a/go.mod b/go.mod index cf6f6a2e..a3f63b6b 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.14 require ( github.com/MakeNowJust/heredoc/v2 v2.0.1 + github.com/coreos/go-oidc v2.2.1+incompatible github.com/davecgh/go-spew v1.1.1 github.com/ghodss/yaml v1.0.0 github.com/go-logr/logr v0.2.1 @@ -11,6 +12,8 @@ require ( github.com/golang/mock v1.4.4 github.com/golangci/golangci-lint v1.31.0 github.com/google/go-cmp v0.5.2 + github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 + github.com/pkg/errors v0.9.1 github.com/sclevine/spec v1.4.0 github.com/spf13/cobra v1.0.0 github.com/spf13/pflag v1.0.5 @@ -18,6 +21,7 @@ require ( go.pinniped.dev/generated/1.19/apis v0.0.0-00010101000000-000000000000 go.pinniped.dev/generated/1.19/client v0.0.0-00010101000000-000000000000 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 + golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6 k8s.io/api v0.19.2 k8s.io/apimachinery v0.19.2 k8s.io/apiserver v0.19.2 diff --git a/go.sum b/go.sum index 200cf57f..dae855b9 100644 --- a/go.sum +++ b/go.sum @@ -88,6 +88,8 @@ github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkE github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= +github.com/coreos/go-oidc v2.2.1+incompatible h1:mh48q/BqXqgjVHpy2ZY7WnWAbenxRjsz9N1i1YxjHAk= +github.com/coreos/go-oidc v2.2.1+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmfM= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= @@ -450,6 +452,8 @@ github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/9 github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/phayes/checkstyle v0.0.0-20170904204023-bfd46e6a821d h1:CdDQnGF8Nq9ocOS/xlSptM1N3BbrA6/kmaep5ggwaIA= github.com/phayes/checkstyle v0.0.0-20170904204023-bfd46e6a821d/go.mod h1:3OzsM7FXDQlpCiw2j81fOmAwQLnZnLGXVKUzeKQXIAw= +github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 h1:49lOXmGaUpV9Fz3gd7TFZY106KVlPVa5jcYD1gaQf98= +github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -457,6 +461,7 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI= +github.com/pquerna/cachecontrol v0.0.0-20171018203845-0dec1b30a021 h1:0XM1XL/OFFJjXsYXlG30spTkV/E9+gmd5GD1w2HE8xM= github.com/pquerna/cachecontrol v0.0.0-20171018203845-0dec1b30a021/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso= @@ -835,6 +840,7 @@ gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8= gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= +gopkg.in/square/go-jose.v2 v2.2.2 h1:orlkJ3myw8CN1nVQHBFfloD+L3egixIa4FvUP6RosSA= gopkg.in/square/go-jose.v2 v2.2.2/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= diff --git a/internal/oidc/pkce/pkce.go b/internal/oidc/pkce/pkce.go new file mode 100644 index 00000000..309b3d4d --- /dev/null +++ b/internal/oidc/pkce/pkce.go @@ -0,0 +1,45 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package pkce + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "io" + + "github.com/pkg/errors" + "golang.org/x/oauth2" +) + +// Generate generates a new random PKCE code. +func Generate() (Code, error) { return generate(rand.Reader) } + +func generate(rand io.Reader) (Code, error) { + var buf [32]byte + if _, err := io.ReadFull(rand, buf[:]); err != nil { + return "", errors.WithMessage(err, "could not generate PKCE code") + } + return Code(hex.EncodeToString(buf[:])), nil +} + +// Code implements the basic options required for RFC 7636: Proof Key for Code Exchange (PKCE). +type Code string + +// Challenge returns the OAuth2 auth code parameter for sending the PKCE code challenge. +func (p *Code) Challenge() oauth2.AuthCodeOption { + b := sha256.Sum256([]byte(*p)) + return oauth2.SetAuthURLParam("code_challenge", base64.RawURLEncoding.EncodeToString(b[:])) +} + +// Method returns the OAuth2 auth code parameter for sending the PKCE code challenge method. +func (p *Code) Method() oauth2.AuthCodeOption { + return oauth2.SetAuthURLParam("code_challenge_method", "S256") +} + +// Verifier returns the OAuth2 auth code parameter for sending the PKCE code verifier. +func (p *Code) Verifier() oauth2.AuthCodeOption { + return oauth2.SetAuthURLParam("code_verifier", string(*p)) +} diff --git a/internal/oidc/pkce/pkce_test.go b/internal/oidc/pkce/pkce_test.go new file mode 100644 index 00000000..be611378 --- /dev/null +++ b/internal/oidc/pkce/pkce_test.go @@ -0,0 +1,42 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package pkce + +import ( + "bytes" + "encoding/base64" + "net/url" + "testing" + + "golang.org/x/oauth2" + + "github.com/stretchr/testify/require" +) + +func TestPKCE(t *testing.T) { + p, err := Generate() + require.NoError(t, err) + + cfg := oauth2.Config{} + authCodeURL, err := url.Parse(cfg.AuthCodeURL("", p.Challenge(), p.Method())) + require.NoError(t, err) + + // The code_challenge must be 256 bits (sha256) encoded as unpadded urlsafe base64. + chal, err := base64.RawURLEncoding.DecodeString(authCodeURL.Query().Get("code_challenge")) + require.NoError(t, err) + require.Len(t, chal, 32) + + // The code_challenge_method must be a fixed value. + require.Equal(t, "S256", authCodeURL.Query().Get("code_challenge_method")) + + // The code_verifier param should be 64 hex characters. + verifyURL, err := url.Parse(cfg.AuthCodeURL("", p.Verifier())) + require.NoError(t, err) + require.Regexp(t, `\A[0-9a-f]{64}\z`, verifyURL.Query().Get("code_verifier")) + + var empty bytes.Buffer + p, err = generate(&empty) + require.EqualError(t, err, "could not generate PKCE code: EOF") + require.Empty(t, p) +} diff --git a/internal/oidc/state/state.go b/internal/oidc/state/state.go new file mode 100644 index 00000000..7d70e51b --- /dev/null +++ b/internal/oidc/state/state.go @@ -0,0 +1,37 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package state + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/hex" + "io" + + "github.com/pkg/errors" +) + +// Generate generates a new random state parameter of an appropriate size. +func Generate() (State, error) { return generate(rand.Reader) } + +func generate(rand io.Reader) (State, error) { + var buf [16]byte + if _, err := io.ReadFull(rand, buf[:]); err != nil { + return "", errors.WithMessage(err, "could not generate random state") + } + return State(hex.EncodeToString(buf[:])), nil +} + +// State implements some utilities for working with OAuth2 state parameters. +type State string + +// String returns the string encoding of this state value. +func (s *State) String() string { + return string(*s) +} + +// Validate the returned state (from a callback parameter). Returns true iff the state is valid. +func (s *State) Valid(returnedState string) bool { + return subtle.ConstantTimeCompare([]byte(returnedState), []byte(*s)) == 1 +} diff --git a/internal/oidc/state/state_test.go b/internal/oidc/state/state_test.go new file mode 100644 index 00000000..ff181839 --- /dev/null +++ b/internal/oidc/state/state_test.go @@ -0,0 +1,25 @@ +// Copyright 2020 the Pinniped contributors. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +package state + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestState(t *testing.T) { + s, err := Generate() + require.NoError(t, err) + require.Len(t, s, 32) + require.Len(t, s.String(), 32) + require.True(t, s.Valid(string(s))) + require.False(t, s.Valid(string(s)+"x")) + + var empty bytes.Buffer + s, err = generate(&empty) + require.EqualError(t, err, "could not generate random state: EOF") + require.Empty(t, s) +}