Merge pull request #147 from mattmoyer/oidc-cli

Implement initial steps of OIDC CLI client.
This commit is contained in:
Matt Moyer 2020-10-06 15:20:30 -05:00 committed by GitHub
commit 8012d6a1c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 585 additions and 4 deletions

22
cmd/pinniped/cmd/alpha.go Normal file
View File

@ -0,0 +1,22 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package cmd
import (
"github.com/spf13/cobra"
)
//nolint: gochecknoglobals
var alphaCmd = &cobra.Command{
Use: "alpha",
Short: "alpha",
Long: "alpha subcommands (syntax or flags are still subject to change)",
SilenceUsage: true, // do not print usage message when commands fail
Hidden: true,
}
//nolint: gochecknoinits
func init() {
rootCmd.AddCommand(alphaCmd)
}

View File

@ -0,0 +1,15 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package cmd
import "github.com/spf13/cobra"
// mustMarkRequired marks the given flags as required on the provided cobra.Command. If any of the names are wrong, it panics.
func mustMarkRequired(cmd *cobra.Command, flags ...string) {
for _, flag := range flags {
if err := cmd.MarkFlagRequired(flag); err != nil {
panic(err)
}
}
}

View File

@ -0,0 +1,21 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package cmd
import (
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
)
func TestMustMarkRequired(t *testing.T) {
require.NotPanics(t, func() { mustMarkRequired(&cobra.Command{}) })
require.NotPanics(t, func() {
cmd := &cobra.Command{}
cmd.Flags().String("known-flag", "", "")
mustMarkRequired(cmd, "known-flag")
})
require.Panics(t, func() { mustMarkRequired(&cobra.Command{}, "unknown-flag") })
}

View File

@ -85,15 +85,12 @@ func (c *getKubeConfigCommand) Command() *cobra.Command {
`),
}
cmd.Flags().StringVar(&c.flags.token, "token", "", "Credential to include in the resulting kubeconfig output (Required)")
err := cmd.MarkFlagRequired("token")
if err != nil {
panic(err)
}
cmd.Flags().StringVar(&c.flags.kubeconfig, "kubeconfig", c.flags.kubeconfig, "Path to the kubeconfig file")
cmd.Flags().StringVar(&c.flags.contextOverride, "kubeconfig-context", c.flags.contextOverride, "Kubeconfig context override")
cmd.Flags().StringVar(&c.flags.namespace, "pinniped-namespace", c.flags.namespace, "Namespace in which Pinniped was installed")
cmd.Flags().StringVar(&c.flags.idpType, "idp-type", c.flags.idpType, "Identity provider type (e.g., 'webhook')")
cmd.Flags().StringVar(&c.flags.idpName, "idp-name", c.flags.idpType, "Identity provider name")
mustMarkRequired(cmd, "token")
return cmd
}

21
cmd/pinniped/cmd/login.go Normal file
View File

@ -0,0 +1,21 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package cmd
import (
"github.com/spf13/cobra"
)
//nolint: gochecknoglobals
var loginCmd = &cobra.Command{
Use: "login",
Short: "login",
Long: "Login to a Pinniped server",
SilenceUsage: true, // do not print usage message when commands fail
}
//nolint: gochecknoinits
func init() {
alphaCmd.AddCommand(loginCmd)
}

View File

@ -0,0 +1,126 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package cmd
import (
"fmt"
"github.com/coreos/go-oidc"
"github.com/pkg/browser"
"github.com/spf13/cobra"
"golang.org/x/oauth2"
"go.pinniped.dev/internal/here"
"go.pinniped.dev/internal/oidc/pkce"
"go.pinniped.dev/internal/oidc/state"
)
//nolint: gochecknoinits
func init() {
loginCmd.AddCommand((&oidcLoginParams{
generateState: state.Generate,
generatePKCE: pkce.Generate,
openURL: browser.OpenURL,
}).cmd())
}
type oidcLoginParams struct {
// These parameters capture CLI flags.
issuer string
clientID string
listenPort uint16
scopes []string
skipBrowser bool
usePKCE bool
debugAuthCode bool
// These parameters capture dependencies that we want to mock during testing.
generateState func() (state.State, error)
generatePKCE func() (pkce.Code, error)
openURL func(string) error
}
func (o *oidcLoginParams) cmd() *cobra.Command {
cmd := cobra.Command{
Args: cobra.NoArgs,
Use: "oidc --issuer ISSUER --client-id CLIENT_ID",
Short: "Login using an OpenID Connect provider",
RunE: o.runE,
SilenceUsage: true,
}
cmd.Flags().StringVar(&o.issuer, "issuer", "", "OpenID Connect issuer URL.")
cmd.Flags().StringVar(&o.clientID, "client-id", "", "OpenID Connect client ID.")
cmd.Flags().Uint16Var(&o.listenPort, "listen-port", 48095, "TCP port for localhost listener (authorization code flow only).")
cmd.Flags().StringSliceVar(&o.scopes, "scopes", []string{"offline_access", "openid", "email", "profile"}, "OIDC scopes to request during login.")
cmd.Flags().BoolVar(&o.skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL).")
cmd.Flags().BoolVar(&o.usePKCE, "use-pkce", true, "Use Proof Key for Code Exchange (RFC 7636) during login.")
mustMarkRequired(&cmd, "issuer", "client-id")
// TODO: temporary
cmd.Flags().BoolVar(&o.debugAuthCode, "debug-auth-code-exchange", true, "Debug the authorization code exchange (temporary).")
_ = cmd.Flags().MarkHidden("debug-auth-code-exchange")
return &cmd
}
func (o *oidcLoginParams) runE(cmd *cobra.Command, args []string) error {
metadata, err := oidc.NewProvider(cmd.Context(), o.issuer)
if err != nil {
return fmt.Errorf("could not perform OIDC discovery for %q: %w", o.issuer, err)
}
cfg := oauth2.Config{
ClientID: o.clientID,
Endpoint: metadata.Endpoint(),
RedirectURL: fmt.Sprintf("http://localhost:%d/callback", o.listenPort),
Scopes: o.scopes,
}
authCodeOptions := []oauth2.AuthCodeOption{oauth2.AccessTypeOffline}
stateParam, err := o.generateState()
if err != nil {
return fmt.Errorf("could not generate OIDC state parameter: %w", err)
}
var pkceCode pkce.Code
if o.usePKCE {
pkceCode, err = o.generatePKCE()
if err != nil {
return fmt.Errorf("could not generate OIDC PKCE parameter: %w", err)
}
authCodeOptions = append(authCodeOptions, pkceCode.Challenge(), pkceCode.Method())
}
// If --skip-browser was passed, override the default browser open function with a Printf() call.
openURL := o.openURL
if o.skipBrowser {
openURL = func(s string) error {
cmd.PrintErr("Please log in: ", s, "\n")
return nil
}
}
authorizeURL := cfg.AuthCodeURL(stateParam.String(), authCodeOptions...)
if err := openURL(authorizeURL); err != nil {
return fmt.Errorf("could not open browser (run again with --skip-browser?): %w", err)
}
// TODO: this temporary so we can complete the auth code exchange manually
if o.debugAuthCode {
cmd.PrintErr(here.Docf(`
DEBUG INFO:
Token URL: %s
State: %s
PKCE: %s
`,
cfg.Endpoint.TokenURL,
stateParam,
pkceCode.Verifier(),
))
}
return nil
}

View File

@ -0,0 +1,220 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package cmd
import (
"bytes"
"fmt"
"net/http"
"net/http/httptest"
"net/url"
"regexp"
"strings"
"testing"
"github.com/spf13/cobra"
"github.com/stretchr/testify/require"
"go.pinniped.dev/internal/here"
"go.pinniped.dev/internal/oidc/pkce"
"go.pinniped.dev/internal/oidc/state"
)
func TestLoginOIDCCommand(t *testing.T) {
t.Parallel()
tests := []struct {
name string
args []string
wantError bool
wantStdout string
wantStderr string
}{
{
name: "help flag passed",
args: []string{"--help"},
wantStdout: here.Doc(`
Login using an OpenID Connect provider
Usage:
oidc --issuer ISSUER --client-id CLIENT_ID [flags]
Flags:
--client-id string OpenID Connect client ID.
-h, --help help for oidc
--issuer string OpenID Connect issuer URL.
--listen-port uint16 TCP port for localhost listener (authorization code flow only). (default 48095)
--scopes strings OIDC scopes to request during login. (default [offline_access,openid,email,profile])
--skip-browser Skip opening the browser (just print the URL).
--use-pkce Use Proof Key for Code Exchange (RFC 7636) during login. (default true)
`),
},
{
name: "missing required flags",
args: []string{},
wantError: true,
wantStdout: here.Doc(`
Error: required flag(s) "client-id", "issuer" not set
`),
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
cmd := (&oidcLoginParams{}).cmd()
require.NotNil(t, cmd)
var stdout, stderr bytes.Buffer
cmd.SetOut(&stdout)
cmd.SetErr(&stderr)
cmd.SetArgs(tt.args)
err := cmd.Execute()
if tt.wantError {
require.Error(t, err)
} else {
require.NoError(t, err)
}
require.Equal(t, tt.wantStdout, stdout.String(), "unexpected stdout")
require.Equal(t, tt.wantStderr, stderr.String(), "unexpected stderr")
})
}
}
func TestOIDCLoginRunE(t *testing.T) {
t.Parallel()
// Start a server that returns 500 errors.
brokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}))
t.Cleanup(brokenServer.Close)
// Start a server that returns successfully.
var validResponse string
validServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(validResponse))
}))
t.Cleanup(validServer.Close)
validResponse = strings.ReplaceAll(here.Docf(`
{
"issuer": "${ISSUER}",
"authorization_endpoint": "${ISSUER}/auth",
"token_endpoint": "${ISSUER}/token",
"jwks_uri": "${ISSUER}/keys",
"userinfo_endpoint": "${ISSUER}/userinfo",
"grant_types_supported": ["authorization_code","refresh_token"],
"response_types_supported": ["code"],
"id_token_signing_alg_values_supported": ["RS256"],
"scopes_supported": ["openid","email","groups","profile","offline_access"],
"token_endpoint_auth_methods_supported": ["client_secret_basic"],
"claims_supported": ["aud","email","email_verified","exp","iat","iss","locale","name","sub"]
}
`), "${ISSUER}", validServer.URL)
validServerURL, err := url.Parse(validServer.URL)
require.NoError(t, err)
tests := []struct {
name string
params oidcLoginParams
wantError string
wantStdout string
wantStderr string
wantStderrAuthURL func(*testing.T, *url.URL)
}{
{
name: "broken discovery",
params: oidcLoginParams{
issuer: brokenServer.URL,
},
wantError: fmt.Sprintf("could not perform OIDC discovery for %q: 500 Internal Server Error: Internal Server Error\n", brokenServer.URL),
},
{
name: "broken state generation",
params: oidcLoginParams{
issuer: validServer.URL,
generateState: func() (state.State, error) { return "", fmt.Errorf("some error generating a state value") },
},
wantError: "could not generate OIDC state parameter: some error generating a state value",
},
{
name: "broken PKCE generation",
params: oidcLoginParams{
issuer: validServer.URL,
generateState: func() (state.State, error) { return "test-state", nil },
usePKCE: true,
generatePKCE: func() (pkce.Code, error) { return "", fmt.Errorf("some error generating a PKCE code") },
},
wantError: "could not generate OIDC PKCE parameter: some error generating a PKCE code",
},
{
name: "broken browser open",
params: oidcLoginParams{
issuer: validServer.URL,
generateState: func() (state.State, error) { return "test-state", nil },
usePKCE: true,
generatePKCE: func() (pkce.Code, error) { return "test-pkce", nil },
openURL: func(_ string) error { return fmt.Errorf("some browser open error") },
},
wantError: "could not open browser (run again with --skip-browser?): some browser open error",
},
{
name: "success without PKCE",
params: oidcLoginParams{
issuer: validServer.URL,
clientID: "test-client-id",
generateState: func() (state.State, error) { return "test-state", nil },
usePKCE: false,
listenPort: 12345,
skipBrowser: true,
},
wantStderrAuthURL: func(t *testing.T, actual *url.URL) {
require.Equal(t, validServerURL.Host, actual.Host)
require.Equal(t, "/auth", actual.Path)
require.Equal(t, "", actual.Fragment)
require.Equal(t, url.Values{
"access_type": []string{"offline"},
"client_id": []string{"test-client-id"},
"redirect_uri": []string{"http://localhost:12345/callback"},
"response_type": []string{"code"},
"state": []string{"test-state"},
}, actual.Query())
},
wantStderr: "Please log in: <URL>\n",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
var stdout, stderr bytes.Buffer
cmd := cobra.Command{RunE: tt.params.runE, SilenceUsage: true, SilenceErrors: true}
cmd.SetOut(&stdout)
cmd.SetErr(&stderr)
err := cmd.Execute()
if tt.wantError != "" {
require.EqualError(t, err, tt.wantError)
} else {
require.NoError(t, err)
}
if tt.wantStderrAuthURL != nil {
var urls []string
redacted := regexp.MustCompile(`http://\S+`).ReplaceAllStringFunc(stderr.String(), func(url string) string {
urls = append(urls, url)
return "<URL>"
})
require.Lenf(t, urls, 1, "expected to find authorization URL in stderr:\n%s", stderr.String())
authURL, err := url.Parse(urls[0])
require.NoError(t, err, "invalid authorization URL")
tt.wantStderrAuthURL(t, authURL)
// Replace the stderr buffer with the redacted version.
stderr.Reset()
stderr.WriteString(redacted)
}
require.Equal(t, tt.wantStdout, stdout.String(), "unexpected stdout")
require.Equal(t, tt.wantStderr, stderr.String(), "unexpected stderr")
})
}
}

4
go.mod
View File

@ -4,6 +4,7 @@ go 1.14
require (
github.com/MakeNowJust/heredoc/v2 v2.0.1
github.com/coreos/go-oidc v2.2.1+incompatible
github.com/davecgh/go-spew v1.1.1
github.com/ghodss/yaml v1.0.0
github.com/go-logr/logr v0.2.1
@ -11,6 +12,8 @@ require (
github.com/golang/mock v1.4.4
github.com/golangci/golangci-lint v1.31.0
github.com/google/go-cmp v0.5.2
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4
github.com/pkg/errors v0.9.1
github.com/sclevine/spec v1.4.0
github.com/spf13/cobra v1.0.0
github.com/spf13/pflag v1.0.5
@ -18,6 +21,7 @@ require (
go.pinniped.dev/generated/1.19/apis v0.0.0-00010101000000-000000000000
go.pinniped.dev/generated/1.19/client v0.0.0-00010101000000-000000000000
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6
k8s.io/api v0.19.2
k8s.io/apimachinery v0.19.2
k8s.io/apiserver v0.19.2

6
go.sum
View File

@ -88,6 +88,8 @@ github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkE
github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE=
github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
github.com/coreos/go-oidc v2.2.1+incompatible h1:mh48q/BqXqgjVHpy2ZY7WnWAbenxRjsz9N1i1YxjHAk=
github.com/coreos/go-oidc v2.2.1+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc=
github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmfM=
github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
@ -450,6 +452,8 @@ github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/9
github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU=
github.com/phayes/checkstyle v0.0.0-20170904204023-bfd46e6a821d h1:CdDQnGF8Nq9ocOS/xlSptM1N3BbrA6/kmaep5ggwaIA=
github.com/phayes/checkstyle v0.0.0-20170904204023-bfd46e6a821d/go.mod h1:3OzsM7FXDQlpCiw2j81fOmAwQLnZnLGXVKUzeKQXIAw=
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 h1:49lOXmGaUpV9Fz3gd7TFZY106KVlPVa5jcYD1gaQf98=
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
@ -457,6 +461,7 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/posener/complete v1.1.1/go.mod h1:em0nMJCgc9GFtwrmVmEMR/ZL6WyhyjMBndrE9hABlRI=
github.com/pquerna/cachecontrol v0.0.0-20171018203845-0dec1b30a021 h1:0XM1XL/OFFJjXsYXlG30spTkV/E9+gmd5GD1w2HE8xM=
github.com/pquerna/cachecontrol v0.0.0-20171018203845-0dec1b30a021/go.mod h1:prYjPmNq4d1NPVmpShWobRqXY3q7Vp+80DqgxxUrUIA=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v0.9.3/go.mod h1:/TN21ttK/J9q6uSwhBd54HahCDft0ttaMvbicHlPoso=
@ -835,6 +840,7 @@ gopkg.in/ini.v1 v1.51.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.0.0 h1:1Lc07Kr7qY4U2YPouBjpCLxpiyxIVoxqXgkXLknAOE8=
gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k=
gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
gopkg.in/square/go-jose.v2 v2.2.2 h1:orlkJ3myw8CN1nVQHBFfloD+L3egixIa4FvUP6RosSA=
gopkg.in/square/go-jose.v2 v2.2.2/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=

View File

@ -0,0 +1,45 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package pkce
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"io"
"github.com/pkg/errors"
"golang.org/x/oauth2"
)
// Generate generates a new random PKCE code.
func Generate() (Code, error) { return generate(rand.Reader) }
func generate(rand io.Reader) (Code, error) {
var buf [32]byte
if _, err := io.ReadFull(rand, buf[:]); err != nil {
return "", errors.WithMessage(err, "could not generate PKCE code")
}
return Code(hex.EncodeToString(buf[:])), nil
}
// Code implements the basic options required for RFC 7636: Proof Key for Code Exchange (PKCE).
type Code string
// Challenge returns the OAuth2 auth code parameter for sending the PKCE code challenge.
func (p *Code) Challenge() oauth2.AuthCodeOption {
b := sha256.Sum256([]byte(*p))
return oauth2.SetAuthURLParam("code_challenge", base64.RawURLEncoding.EncodeToString(b[:]))
}
// Method returns the OAuth2 auth code parameter for sending the PKCE code challenge method.
func (p *Code) Method() oauth2.AuthCodeOption {
return oauth2.SetAuthURLParam("code_challenge_method", "S256")
}
// Verifier returns the OAuth2 auth code parameter for sending the PKCE code verifier.
func (p *Code) Verifier() oauth2.AuthCodeOption {
return oauth2.SetAuthURLParam("code_verifier", string(*p))
}

View File

@ -0,0 +1,42 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package pkce
import (
"bytes"
"encoding/base64"
"net/url"
"testing"
"golang.org/x/oauth2"
"github.com/stretchr/testify/require"
)
func TestPKCE(t *testing.T) {
p, err := Generate()
require.NoError(t, err)
cfg := oauth2.Config{}
authCodeURL, err := url.Parse(cfg.AuthCodeURL("", p.Challenge(), p.Method()))
require.NoError(t, err)
// The code_challenge must be 256 bits (sha256) encoded as unpadded urlsafe base64.
chal, err := base64.RawURLEncoding.DecodeString(authCodeURL.Query().Get("code_challenge"))
require.NoError(t, err)
require.Len(t, chal, 32)
// The code_challenge_method must be a fixed value.
require.Equal(t, "S256", authCodeURL.Query().Get("code_challenge_method"))
// The code_verifier param should be 64 hex characters.
verifyURL, err := url.Parse(cfg.AuthCodeURL("", p.Verifier()))
require.NoError(t, err)
require.Regexp(t, `\A[0-9a-f]{64}\z`, verifyURL.Query().Get("code_verifier"))
var empty bytes.Buffer
p, err = generate(&empty)
require.EqualError(t, err, "could not generate PKCE code: EOF")
require.Empty(t, p)
}

View File

@ -0,0 +1,37 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package state
import (
"crypto/rand"
"crypto/subtle"
"encoding/hex"
"io"
"github.com/pkg/errors"
)
// Generate generates a new random state parameter of an appropriate size.
func Generate() (State, error) { return generate(rand.Reader) }
func generate(rand io.Reader) (State, error) {
var buf [16]byte
if _, err := io.ReadFull(rand, buf[:]); err != nil {
return "", errors.WithMessage(err, "could not generate random state")
}
return State(hex.EncodeToString(buf[:])), nil
}
// State implements some utilities for working with OAuth2 state parameters.
type State string
// String returns the string encoding of this state value.
func (s *State) String() string {
return string(*s)
}
// Validate the returned state (from a callback parameter). Returns true iff the state is valid.
func (s *State) Valid(returnedState string) bool {
return subtle.ConstantTimeCompare([]byte(returnedState), []byte(*s)) == 1
}

View File

@ -0,0 +1,25 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package state
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
func TestState(t *testing.T) {
s, err := Generate()
require.NoError(t, err)
require.Len(t, s, 32)
require.Len(t, s.String(), 32)
require.True(t, s.Valid(string(s)))
require.False(t, s.Valid(string(s)+"x"))
var empty bytes.Buffer
s, err = generate(&empty)
require.EqualError(t, err, "could not generate random state: EOF")
require.Empty(t, s)
}