Merge pull request #165 from mattmoyer/cli-session-cache

Add basic file-based session cache for CLI OIDC client.
This commit is contained in:
Matt Moyer 2020-10-21 16:30:03 -05:00 committed by GitHub
commit 5dbc03efe9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 1377 additions and 53 deletions

View File

@ -13,3 +13,12 @@ func mustMarkRequired(cmd *cobra.Command, flags ...string) {
} }
} }
} }
// mustMarkHidden marks the given flags as hidden on the provided cobra.Command. If any of the names are wrong, it panics.
func mustMarkHidden(cmd *cobra.Command, flags ...string) {
for _, flag := range flags {
if err := cmd.Flags().MarkHidden(flag); err != nil {
panic(err)
}
}
}

View File

@ -5,20 +5,24 @@ package cmd
import ( import (
"encoding/json" "encoding/json"
"os"
"path/filepath"
"github.com/spf13/cobra" "github.com/spf13/cobra"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
clientauthenticationv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1" clientauthenticationv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1"
"k8s.io/klog/v2/klogr"
"go.pinniped.dev/internal/oidcclient/login" "go.pinniped.dev/internal/oidcclient"
"go.pinniped.dev/internal/oidcclient/filesession"
) )
//nolint: gochecknoinits //nolint: gochecknoinits
func init() { func init() {
loginCmd.AddCommand(oidcLoginCommand(login.Run)) loginCmd.AddCommand(oidcLoginCommand(oidcclient.Login))
} }
func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...login.Option) (*login.Token, error)) *cobra.Command { func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oidcclient.Option) (*oidcclient.Token, error)) *cobra.Command {
var ( var (
cmd = cobra.Command{ cmd = cobra.Command{
Args: cobra.NoArgs, Args: cobra.NoArgs,
@ -31,27 +35,46 @@ func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...log
listenPort uint16 listenPort uint16
scopes []string scopes []string
skipBrowser bool skipBrowser bool
sessionCachePath string
debugSessionCache bool
) )
cmd.Flags().StringVar(&issuer, "issuer", "", "OpenID Connect issuer URL.") cmd.Flags().StringVar(&issuer, "issuer", "", "OpenID Connect issuer URL.")
cmd.Flags().StringVar(&clientID, "client-id", "", "OpenID Connect client ID.") cmd.Flags().StringVar(&clientID, "client-id", "", "OpenID Connect client ID.")
cmd.Flags().Uint16Var(&listenPort, "listen-port", 0, "TCP port for localhost listener (authorization code flow only).") cmd.Flags().Uint16Var(&listenPort, "listen-port", 0, "TCP port for localhost listener (authorization code flow only).")
cmd.Flags().StringSliceVar(&scopes, "scopes", []string{"offline_access", "openid", "email", "profile"}, "OIDC scopes to request during login.") cmd.Flags().StringSliceVar(&scopes, "scopes", []string{"offline_access", "openid", "email", "profile"}, "OIDC scopes to request during login.")
cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL).") cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL).")
cmd.Flags().StringVar(&sessionCachePath, "session-cache", filepath.Join(mustGetConfigDir(), "sessions.yaml"), "Path to session cache file.")
cmd.Flags().BoolVar(&debugSessionCache, "debug-session-cache", false, "Print debug logs related to the session cache.")
mustMarkHidden(&cmd, "debug-session-cache")
mustMarkRequired(&cmd, "issuer", "client-id") mustMarkRequired(&cmd, "issuer", "client-id")
cmd.RunE = func(cmd *cobra.Command, args []string) error { cmd.RunE = func(cmd *cobra.Command, args []string) error {
opts := []login.Option{ // Initialize the session cache.
login.WithContext(cmd.Context()), var sessionOptions []filesession.Option
login.WithScopes(scopes),
// If the hidden --debug-session-cache option is passed, log all the errors from the session cache with klog.
if debugSessionCache {
logger := klogr.New().WithName("session")
sessionOptions = append(sessionOptions, filesession.WithErrorReporter(func(err error) {
logger.Error(err, "error during session cache operation")
}))
}
sessionCache := filesession.New(sessionCachePath, sessionOptions...)
// Initialize the login handler.
opts := []oidcclient.Option{
oidcclient.WithContext(cmd.Context()),
oidcclient.WithScopes(scopes),
oidcclient.WithSessionCache(sessionCache),
} }
if listenPort != 0 { if listenPort != 0 {
opts = append(opts, login.WithListenPort(listenPort)) opts = append(opts, oidcclient.WithListenPort(listenPort))
} }
// --skip-browser replaces the default "browser open" function with one that prints to stderr. // --skip-browser replaces the default "browser open" function with one that prints to stderr.
if skipBrowser { if skipBrowser {
opts = append(opts, login.WithBrowserOpen(func(url string) error { opts = append(opts, oidcclient.WithBrowserOpen(func(url string) error {
cmd.PrintErr("Please log in: ", url, "\n") cmd.PrintErr("Please log in: ", url, "\n")
return nil return nil
})) }))
@ -69,10 +92,27 @@ func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...log
APIVersion: "client.authentication.k8s.io/v1beta1", APIVersion: "client.authentication.k8s.io/v1beta1",
}, },
Status: &clientauthenticationv1beta1.ExecCredentialStatus{ Status: &clientauthenticationv1beta1.ExecCredentialStatus{
ExpirationTimestamp: &metav1.Time{Time: tok.IDTokenExpiry}, ExpirationTimestamp: &tok.IDToken.Expiry,
Token: tok.IDToken, Token: tok.IDToken.Token,
}, },
}) })
} }
return &cmd return &cmd
} }
// mustGetConfigDir returns a directory that follows the XDG base directory convention:
// $XDG_CONFIG_HOME defines the base directory relative to which user specific configuration files should
// be stored. If $XDG_CONFIG_HOME is either not set or empty, a default equal to $HOME/.config should be used.
// [1] https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html
func mustGetConfigDir() string {
const xdgAppName = "pinniped"
if path := os.Getenv("XDG_CONFIG_HOME"); path != "" {
return filepath.Join(path, xdgAppName)
}
home, err := os.UserHomeDir()
if err != nil {
panic(err)
}
return filepath.Join(home, ".config", xdgAppName)
}

View File

@ -9,14 +9,17 @@ import (
"time" "time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/here" "go.pinniped.dev/internal/here"
"go.pinniped.dev/internal/oidcclient/login" "go.pinniped.dev/internal/oidcclient"
) )
func TestLoginOIDCCommand(t *testing.T) { func TestLoginOIDCCommand(t *testing.T) {
t.Parallel() t.Parallel()
cfgDir := mustGetConfigDir()
time1 := time.Date(3020, 10, 12, 13, 14, 15, 16, time.UTC) time1 := time.Date(3020, 10, 12, 13, 14, 15, 16, time.UTC)
tests := []struct { tests := []struct {
@ -44,6 +47,7 @@ func TestLoginOIDCCommand(t *testing.T) {
--issuer string OpenID Connect issuer URL. --issuer string OpenID Connect issuer URL.
--listen-port uint16 TCP port for localhost listener (authorization code flow only). --listen-port uint16 TCP port for localhost listener (authorization code flow only).
--scopes strings OIDC scopes to request during login. (default [offline_access,openid,email,profile]) --scopes strings OIDC scopes to request during login. (default [offline_access,openid,email,profile])
--session-cache string Path to session cache file. (default "` + cfgDir + `/sessions.yaml")
--skip-browser Skip opening the browser (just print the URL). --skip-browser Skip opening the browser (just print the URL).
`), `),
}, },
@ -63,7 +67,7 @@ func TestLoginOIDCCommand(t *testing.T) {
}, },
wantIssuer: "test-issuer", wantIssuer: "test-issuer",
wantClientID: "test-client-id", wantClientID: "test-client-id",
wantOptionsCount: 2, wantOptionsCount: 3,
wantStdout: `{"kind":"ExecCredential","apiVersion":"client.authentication.k8s.io/v1beta1","spec":{},"status":{"expirationTimestamp":"3020-10-12T13:14:15Z","token":"test-id-token"}}` + "\n", wantStdout: `{"kind":"ExecCredential","apiVersion":"client.authentication.k8s.io/v1beta1","spec":{},"status":{"expirationTimestamp":"3020-10-12T13:14:15Z","token":"test-id-token"}}` + "\n",
}, },
{ {
@ -73,10 +77,11 @@ func TestLoginOIDCCommand(t *testing.T) {
"--issuer", "test-issuer", "--issuer", "test-issuer",
"--skip-browser", "--skip-browser",
"--listen-port", "1234", "--listen-port", "1234",
"--debug-session-cache",
}, },
wantIssuer: "test-issuer", wantIssuer: "test-issuer",
wantClientID: "test-client-id", wantClientID: "test-client-id",
wantOptionsCount: 4, wantOptionsCount: 5,
wantStdout: `{"kind":"ExecCredential","apiVersion":"client.authentication.k8s.io/v1beta1","spec":{},"status":{"expirationTimestamp":"3020-10-12T13:14:15Z","token":"test-id-token"}}` + "\n", wantStdout: `{"kind":"ExecCredential","apiVersion":"client.authentication.k8s.io/v1beta1","spec":{},"status":{"expirationTimestamp":"3020-10-12T13:14:15Z","token":"test-id-token"}}` + "\n",
}, },
} }
@ -87,13 +92,18 @@ func TestLoginOIDCCommand(t *testing.T) {
var ( var (
gotIssuer string gotIssuer string
gotClientID string gotClientID string
gotOptions []login.Option gotOptions []oidcclient.Option
) )
cmd := oidcLoginCommand(func(issuer string, clientID string, opts ...login.Option) (*login.Token, error) { cmd := oidcLoginCommand(func(issuer string, clientID string, opts ...oidcclient.Option) (*oidcclient.Token, error) {
gotIssuer = issuer gotIssuer = issuer
gotClientID = clientID gotClientID = clientID
gotOptions = opts gotOptions = opts
return &login.Token{IDToken: "test-id-token", IDTokenExpiry: time1}, nil return &oidcclient.Token{
IDToken: &oidcclient.IDToken{
Token: "test-id-token",
Expiry: metav1.NewTime(time1),
},
}, nil
}) })
require.NotNil(t, cmd) require.NotNil(t, cmd)

1
go.mod
View File

@ -10,6 +10,7 @@ require (
github.com/ghodss/yaml v1.0.0 github.com/ghodss/yaml v1.0.0
github.com/go-logr/logr v0.2.1 github.com/go-logr/logr v0.2.1
github.com/go-logr/stdr v0.2.0 github.com/go-logr/stdr v0.2.0
github.com/gofrs/flock v0.8.0
github.com/golang/mock v1.4.4 github.com/golang/mock v1.4.4
github.com/golangci/golangci-lint v1.31.0 github.com/golangci/golangci-lint v1.31.0
github.com/google/go-cmp v0.5.2 github.com/google/go-cmp v0.5.2

View File

@ -0,0 +1,157 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package cachefile implements the file format for session caches.
package filesession
import (
"errors"
"fmt"
"io/ioutil"
"os"
"reflect"
"sort"
"time"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/yaml"
"go.pinniped.dev/internal/oidcclient"
)
var (
// errUnsupportedVersion is returned (internally) when we encounter a version of the session cache file that we
// don't understand how to handle (such as one produced by a future version of Pinniped).
errUnsupportedVersion = fmt.Errorf("unsupported session version")
)
const (
// apiVersion is the Kubernetes-style API version of the session file object.
apiVersion = "config.pinniped.dev/v1alpha1"
// apiKind is the Kubernetes-style Kind of the session file object.
apiKind = "SessionCache"
// sessionExpiration is how long a session can remain unused before it is automatically pruned from the session cache.
sessionExpiration = 90 * 24 * time.Hour
)
type (
// sessionCache is the object which is YAML-serialized to form the contents of the cache file.
sessionCache struct {
metav1.TypeMeta
Sessions []sessionEntry `json:"sessions"`
}
// sessionEntry is a single cache entry in the cache file.
sessionEntry struct {
Key oidcclient.SessionCacheKey `json:"key"`
CreationTimestamp metav1.Time `json:"creationTimestamp"`
LastUsedTimestamp metav1.Time `json:"lastUsedTimestamp"`
Tokens oidcclient.Token `json:"tokens"`
}
)
// readSessionCache loads a sessionCache from a path on disk. If the requested path does not exist, it returns an empty cache.
func readSessionCache(path string) (*sessionCache, error) {
cacheYAML, err := ioutil.ReadFile(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
// If the file was not found, generate a freshly initialized empty cache.
return emptySessionCache(), nil
}
// Otherwise bubble up the error.
return nil, fmt.Errorf("could not read session file: %w", err)
}
// If we read the file successfully, unmarshal it from YAML.
var cache sessionCache
if err := yaml.Unmarshal(cacheYAML, &cache); err != nil {
return nil, fmt.Errorf("invalid session file: %w", err)
}
// Validate that we're reading a version of the config we understand how to parse.
if !(cache.TypeMeta.APIVersion == apiVersion && cache.TypeMeta.Kind == apiKind) {
return nil, fmt.Errorf("%w: %#v", errUnsupportedVersion, cache.TypeMeta)
}
return &cache, nil
}
// emptySessionCache returns an empty, initialized sessionCache.
func emptySessionCache() *sessionCache {
return &sessionCache{
TypeMeta: metav1.TypeMeta{APIVersion: apiVersion, Kind: apiKind},
Sessions: make([]sessionEntry, 0, 1),
}
}
// writeTo writes the cache to the specified file path.
func (c *sessionCache) writeTo(path string) error {
// Marshal the session back to YAML and save it to the file.
cacheYAML, err := yaml.Marshal(c)
if err == nil {
err = ioutil.WriteFile(path, cacheYAML, 0600)
}
return err
}
// normalized returns a copy of the sessionCache with stale entries removed and entries sorted in a canonical order.
func (c *sessionCache) normalized() *sessionCache {
result := emptySessionCache()
// Clean up expired/invalid tokens.
now := time.Now()
result.Sessions = make([]sessionEntry, 0, len(c.Sessions))
for _, s := range c.Sessions {
// Nil out any tokens that are empty or expired.
if s.Tokens.IDToken != nil {
if s.Tokens.IDToken.Token == "" || s.Tokens.IDToken.Expiry.Time.Before(now) {
s.Tokens.IDToken = nil
}
}
if s.Tokens.AccessToken != nil {
if s.Tokens.AccessToken.Token == "" || s.Tokens.AccessToken.Expiry.Time.Before(now) {
s.Tokens.AccessToken = nil
}
}
if s.Tokens.RefreshToken != nil && s.Tokens.RefreshToken.Token == "" {
s.Tokens.RefreshToken = nil
}
// Filter out any entries that no longer contain any tokens.
if s.Tokens.IDToken == nil && s.Tokens.AccessToken == nil && s.Tokens.RefreshToken == nil {
continue
}
// Filter out entries that haven't been used in the last sessionExpiration.
cutoff := metav1.NewTime(now.Add(-1 * sessionExpiration))
if s.LastUsedTimestamp.Before(&cutoff) {
continue
}
result.Sessions = append(result.Sessions, s)
}
// Sort the sessions by creation time.
sort.SliceStable(result.Sessions, func(i, j int) bool {
return result.Sessions[i].CreationTimestamp.Before(&result.Sessions[j].CreationTimestamp)
})
return result
}
// lookup a cache entry by key. May return nil.
func (c *sessionCache) lookup(key oidcclient.SessionCacheKey) *sessionEntry {
for i := range c.Sessions {
if reflect.DeepEqual(c.Sessions[i].Key, key) {
return &c.Sessions[i]
}
}
return nil
}
// insert a cache entry.
func (c *sessionCache) insert(entries ...sessionEntry) {
c.Sessions = append(c.Sessions, entries...)
}

View File

@ -0,0 +1,261 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package filesession
import (
"os"
"testing"
"time"
"github.com/stretchr/testify/require"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/oidcclient"
)
// validSession should be the same data as `testdata/valid.yaml`.
var validSession = sessionCache{
TypeMeta: metav1.TypeMeta{APIVersion: "config.pinniped.dev/v1alpha1", Kind: "SessionCache"},
Sessions: []sessionEntry{
{
Key: oidcclient.SessionCacheKey{
Issuer: "test-issuer",
ClientID: "test-client-id",
Scopes: []string{"email", "offline_access", "openid", "profile"},
RedirectURI: "http://localhost:0/callback",
},
CreationTimestamp: metav1.NewTime(time.Date(2020, 10, 20, 18, 42, 7, 0, time.UTC).Local()),
LastUsedTimestamp: metav1.NewTime(time.Date(2020, 10, 20, 18, 45, 31, 0, time.UTC).Local()),
Tokens: oidcclient.Token{
AccessToken: &oidcclient.AccessToken{
Token: "test-access-token",
Type: "Bearer",
Expiry: metav1.NewTime(time.Date(2020, 10, 20, 19, 46, 30, 0, time.UTC).Local()),
},
IDToken: &oidcclient.IDToken{
Token: "test-id-token",
Expiry: metav1.NewTime(time.Date(2020, 10, 20, 19, 42, 07, 0, time.UTC).Local()),
},
RefreshToken: &oidcclient.RefreshToken{
Token: "test-refresh-token",
},
},
},
},
}
func TestReadSessionCache(t *testing.T) {
t.Parallel()
tests := []struct {
name string
path string
want *sessionCache
wantErr string
}{
{
name: "does not exist",
path: "./testdata/does-not-exist.yaml",
want: &sessionCache{
TypeMeta: metav1.TypeMeta{APIVersion: "config.pinniped.dev/v1alpha1", Kind: "SessionCache"},
Sessions: []sessionEntry{},
},
},
{
name: "other file error",
path: "./testdata/",
wantErr: "could not read session file: read ./testdata/: is a directory",
},
{
name: "invalid YAML",
path: "./testdata/invalid.yaml",
wantErr: "invalid session file: error unmarshaling JSON: while decoding JSON: json: cannot unmarshal string into Go value of type filesession.sessionCache",
},
{
name: "wrong version",
path: "./testdata/wrong-version.yaml",
wantErr: `unsupported session version: v1.TypeMeta{Kind:"NotASessionCache", APIVersion:"config.pinniped.dev/v2alpha6"}`,
},
{
name: "valid",
path: "./testdata/valid.yaml",
want: &validSession,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := readSessionCache(tt.path)
if tt.wantErr != "" {
require.EqualError(t, err, tt.wantErr)
require.Nil(t, got)
return
}
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, tt.want, got)
})
}
}
func TestEmptySessionCache(t *testing.T) {
t.Parallel()
got := emptySessionCache()
require.Equal(t, metav1.TypeMeta{APIVersion: "config.pinniped.dev/v1alpha1", Kind: "SessionCache"}, got.TypeMeta)
require.Equal(t, 0, len(got.Sessions))
require.Equal(t, 1, cap(got.Sessions))
}
func TestWriteTo(t *testing.T) {
t.Parallel()
t.Run("io error", func(t *testing.T) {
t.Parallel()
tmp := t.TempDir() + "/sessions.yaml"
require.NoError(t, os.Mkdir(tmp, 0700))
err := validSession.writeTo(tmp)
require.EqualError(t, err, "open "+tmp+": is a directory")
})
t.Run("success", func(t *testing.T) {
t.Parallel()
require.NoError(t, validSession.writeTo(t.TempDir()+"/sessions.yaml"))
})
}
func TestNormalized(t *testing.T) {
t.Parallel()
t.Run("empty", func(t *testing.T) {
t.Parallel()
require.Equal(t, emptySessionCache(), emptySessionCache().normalized())
})
t.Run("nonempty", func(t *testing.T) {
t.Parallel()
input := emptySessionCache()
now := time.Now()
input.Sessions = []sessionEntry{
// ID token is empty, but not nil.
{
LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{
IDToken: &oidcclient.IDToken{
Token: "",
Expiry: metav1.NewTime(now.Add(1 * time.Minute)),
},
},
},
// ID token is expired.
{
LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{
IDToken: &oidcclient.IDToken{
Token: "test-id-token",
Expiry: metav1.NewTime(now.Add(-1 * time.Minute)),
},
},
},
// Access token is empty, but not nil.
{
LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{
AccessToken: &oidcclient.AccessToken{
Token: "",
Expiry: metav1.NewTime(now.Add(1 * time.Minute)),
},
},
},
// Access token is expired.
{
LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{
AccessToken: &oidcclient.AccessToken{
Token: "test-access-token",
Expiry: metav1.NewTime(now.Add(-1 * time.Minute)),
},
},
},
// Refresh token is empty, but not nil.
{
LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{
RefreshToken: &oidcclient.RefreshToken{
Token: "",
},
},
},
// Session has a refresh token but it hasn't been used in >90 days.
{
LastUsedTimestamp: metav1.NewTime(now.AddDate(-1, 0, 0)),
Tokens: oidcclient.Token{
RefreshToken: &oidcclient.RefreshToken{
Token: "test-refresh-token",
},
},
},
// Two entries that are still valid.
{
CreationTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{
RefreshToken: &oidcclient.RefreshToken{
Token: "test-refresh-token2",
},
},
},
{
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{
RefreshToken: &oidcclient.RefreshToken{
Token: "test-refresh-token1",
},
},
},
}
// Expect that all but the last two valid session are pruned, and that they're sorted.
require.Equal(t, &sessionCache{
TypeMeta: metav1.TypeMeta{APIVersion: "config.pinniped.dev/v1alpha1", Kind: "SessionCache"},
Sessions: []sessionEntry{
{
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{
RefreshToken: &oidcclient.RefreshToken{
Token: "test-refresh-token1",
},
},
},
{
CreationTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now),
Tokens: oidcclient.Token{
RefreshToken: &oidcclient.RefreshToken{
Token: "test-refresh-token2",
},
},
},
},
}, input.normalized())
})
}
func TestLookup(t *testing.T) {
t.Parallel()
require.Nil(t, validSession.lookup(oidcclient.SessionCacheKey{}))
require.NotNil(t, validSession.lookup(oidcclient.SessionCacheKey{
Issuer: "test-issuer",
ClientID: "test-client-id",
Scopes: []string{"email", "offline_access", "openid", "profile"},
RedirectURI: "http://localhost:0/callback",
}))
}
func TestInsert(t *testing.T) {
t.Parallel()
c := emptySessionCache()
c.insert(sessionEntry{})
require.Len(t, c.Sessions, 1)
}

View File

@ -0,0 +1,146 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
// Package filesession implements a simple YAML file-based login.sessionCache.
package filesession
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/gofrs/flock"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/oidcclient"
)
const (
// defaultFileLockTimeout is how long we will wait trying to acquire the file lock on the session file before timing out.
defaultFileLockTimeout = 10 * time.Second
// defaultFileLockRetryInterval is how often we will poll while waiting for the file lock to become available.
defaultFileLockRetryInterval = 10 * time.Millisecond
)
// Option configures a cache in New().
type Option func(*Cache)
// WithErrorReporter is an Option that specifies a callback which will be invoked for each error reported during
// session cache operations. By default, these errors are silently ignored.
func WithErrorReporter(reporter func(error)) Option {
return func(c *Cache) {
c.errReporter = reporter
}
}
// New returns a login.SessionCache implementation backed by the specified file path.
func New(path string, options ...Option) *Cache {
lock := flock.New(path + ".lock")
c := Cache{
path: path,
trylockFunc: func() error {
ctx, cancel := context.WithTimeout(context.Background(), defaultFileLockTimeout)
defer cancel()
_, err := lock.TryLockContext(ctx, defaultFileLockRetryInterval)
return err
},
unlockFunc: lock.Unlock,
errReporter: func(_ error) {},
}
for _, opt := range options {
opt(&c)
}
return &c
}
type Cache struct {
path string
errReporter func(error)
trylockFunc func() error
unlockFunc func() error
}
// GetToken looks up the cached data for the given parameters. It may return nil if no valid matching session is cached.
func (c *Cache) GetToken(key oidcclient.SessionCacheKey) *oidcclient.Token {
// If the cache file does not exist, exit immediately with no error log
if _, err := os.Stat(c.path); errors.Is(err, os.ErrNotExist) {
return nil
}
// Read the cache and lookup the matching entry. If one exists, update its last used timestamp and return it.
var result *oidcclient.Token
c.withCache(func(cache *sessionCache) {
if entry := cache.lookup(key); entry != nil {
result = &entry.Tokens
entry.LastUsedTimestamp = metav1.Now()
}
})
return result
}
// PutToken stores the provided token into the session cache under the given parameters. It does not return an error
// but may silently fail to update the session cache.
func (c *Cache) PutToken(key oidcclient.SessionCacheKey, token *oidcclient.Token) {
// Create the cache directory if it does not exist.
if err := os.MkdirAll(filepath.Dir(c.path), 0700); err != nil && !errors.Is(err, os.ErrExist) {
c.errReporter(fmt.Errorf("could not create session cache directory: %w", err))
return
}
// Mutate the cache to upsert the new session entry.
c.withCache(func(cache *sessionCache) {
// Find the existing entry, if one exists
if match := cache.lookup(key); match != nil {
// Update the stored token.
match.Tokens = *token
match.LastUsedTimestamp = metav1.Now()
return
}
// If there's not an entry for this key, insert one.
now := metav1.Now()
cache.insert(sessionEntry{
Key: key,
CreationTimestamp: now,
LastUsedTimestamp: now,
Tokens: *token,
})
})
}
// withCache is an internal helper which locks, reads the cache, processes/mutates it with the provided function, then
// saves it back to the file.
func (c *Cache) withCache(transact func(*sessionCache)) {
// Grab the file lock so we have exclusive access to read the file.
if err := c.trylockFunc(); err != nil {
c.errReporter(fmt.Errorf("could not lock session file: %w", err))
return
}
// Unlock the file at the end of this call, bubbling up the error if things were otherwise successful.
defer func() {
if err := c.unlockFunc(); err != nil {
c.errReporter(fmt.Errorf("could not unlock session file: %w", err))
}
}()
// Try to read the existing cache.
cache, err := readSessionCache(c.path)
if err != nil {
// If that fails, fall back to resetting to a blank slate.
c.errReporter(fmt.Errorf("failed to read cache, resetting: %w", err))
cache = emptySessionCache()
}
// Process/mutate the session using the provided function.
transact(cache)
// Marshal the session back to YAML and save it to the file.
if err := cache.writeTo(c.path); err != nil {
c.errReporter(fmt.Errorf("could not write session cache: %w", err))
}
}

View File

@ -0,0 +1,455 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package filesession
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/oidcclient"
)
func TestNew(t *testing.T) {
t.Parallel()
tmp := t.TempDir() + "/sessions.yaml"
c := New(tmp)
require.NotNil(t, c)
require.Equal(t, tmp, c.path)
require.NotNil(t, c.errReporter)
c.errReporter(fmt.Errorf("some error"))
}
func TestGetToken(t *testing.T) {
t.Parallel()
now := time.Now().Round(1 * time.Second)
tests := []struct {
name string
makeTestFile func(t *testing.T, tmp string)
trylockFunc func(*testing.T) error
unlockFunc func(*testing.T) error
key oidcclient.SessionCacheKey
want *oidcclient.Token
wantErrors []string
wantTestFile func(t *testing.T, tmp string)
}{
{
name: "not found",
key: oidcclient.SessionCacheKey{},
},
{
name: "file lock error",
makeTestFile: func(t *testing.T, tmp string) { require.NoError(t, ioutil.WriteFile(tmp, []byte(""), 0600)) },
trylockFunc: func(t *testing.T) error { return fmt.Errorf("some lock error") },
unlockFunc: func(t *testing.T) error { require.Fail(t, "should not be called"); return nil },
key: oidcclient.SessionCacheKey{},
wantErrors: []string{"could not lock session file: some lock error"},
},
{
name: "invalid file",
makeTestFile: func(t *testing.T, tmp string) {
require.NoError(t, ioutil.WriteFile(tmp, []byte("invalid yaml"), 0600))
},
key: oidcclient.SessionCacheKey{},
wantErrors: []string{
"failed to read cache, resetting: invalid session file: error unmarshaling JSON: while decoding JSON: json: cannot unmarshal string into Go value of type filesession.sessionCache",
},
},
{
name: "invalid file, fail to unlock",
makeTestFile: func(t *testing.T, tmp string) { require.NoError(t, ioutil.WriteFile(tmp, []byte("invalid"), 0600)) },
trylockFunc: func(t *testing.T) error { return nil },
unlockFunc: func(t *testing.T) error { return fmt.Errorf("some unlock error") },
key: oidcclient.SessionCacheKey{},
wantErrors: []string{
"failed to read cache, resetting: invalid session file: error unmarshaling JSON: while decoding JSON: json: cannot unmarshal string into Go value of type filesession.sessionCache",
"could not unlock session file: some unlock error",
},
},
{
name: "unreadable file",
makeTestFile: func(t *testing.T, tmp string) {
require.NoError(t, os.Mkdir(tmp, 0700))
},
key: oidcclient.SessionCacheKey{},
wantErrors: []string{
"failed to read cache, resetting: could not read session file: read TEMPFILE: is a directory",
"could not write session cache: open TEMPFILE: is a directory",
},
},
{
name: "valid file but cache miss",
makeTestFile: func(t *testing.T, tmp string) {
validCache := emptySessionCache()
validCache.insert(sessionEntry{
Key: oidcclient.SessionCacheKey{
Issuer: "not-the-test-issuer",
ClientID: "not-the-test-client-id",
Scopes: []string{"email", "offline_access", "openid", "profile"},
RedirectURI: "http://localhost:0/callback",
},
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
Tokens: oidcclient.Token{
AccessToken: &oidcclient.AccessToken{
Token: "test-access-token",
Type: "Bearer",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
},
IDToken: &oidcclient.IDToken{
Token: "test-id-token",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
},
RefreshToken: &oidcclient.RefreshToken{
Token: "test-refresh-token",
},
},
})
require.NoError(t, validCache.writeTo(tmp))
},
key: oidcclient.SessionCacheKey{
Issuer: "test-issuer",
ClientID: "test-client-id",
Scopes: []string{"email", "offline_access", "openid", "profile"},
RedirectURI: "http://localhost:0/callback",
},
wantErrors: []string{},
},
{
name: "valid file with cache hit",
makeTestFile: func(t *testing.T, tmp string) {
validCache := emptySessionCache()
validCache.insert(sessionEntry{
Key: oidcclient.SessionCacheKey{
Issuer: "test-issuer",
ClientID: "test-client-id",
Scopes: []string{"email", "offline_access", "openid", "profile"},
RedirectURI: "http://localhost:0/callback",
},
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
Tokens: oidcclient.Token{
AccessToken: &oidcclient.AccessToken{
Token: "test-access-token",
Type: "Bearer",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
},
IDToken: &oidcclient.IDToken{
Token: "test-id-token",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
},
RefreshToken: &oidcclient.RefreshToken{
Token: "test-refresh-token",
},
},
})
require.NoError(t, validCache.writeTo(tmp))
},
key: oidcclient.SessionCacheKey{
Issuer: "test-issuer",
ClientID: "test-client-id",
Scopes: []string{"email", "offline_access", "openid", "profile"},
RedirectURI: "http://localhost:0/callback",
},
wantErrors: []string{},
want: &oidcclient.Token{
AccessToken: &oidcclient.AccessToken{
Token: "test-access-token",
Type: "Bearer",
Expiry: metav1.NewTime(now.Add(1 * time.Hour).Local()),
},
IDToken: &oidcclient.IDToken{
Token: "test-id-token",
Expiry: metav1.NewTime(now.Add(1 * time.Hour).Local()),
},
RefreshToken: &oidcclient.RefreshToken{
Token: "test-refresh-token",
},
},
wantTestFile: func(t *testing.T, tmp string) {
cache, err := readSessionCache(tmp)
require.NoError(t, err)
require.Len(t, cache.Sessions, 1)
require.Less(t, time.Since(cache.Sessions[0].LastUsedTimestamp.Time).Nanoseconds(), (5 * time.Second).Nanoseconds())
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
tmp := t.TempDir() + "/sessions.yaml"
if tt.makeTestFile != nil {
tt.makeTestFile(t, tmp)
}
// Initialize a cache with a reporter that collects errors
errors := errorCollector{t: t}
c := New(tmp, errors.collect())
if tt.trylockFunc != nil {
c.trylockFunc = func() error { return tt.trylockFunc(t) }
}
if tt.unlockFunc != nil {
c.unlockFunc = func() error { return tt.unlockFunc(t) }
}
got := c.GetToken(tt.key)
require.Equal(t, tt.want, got)
errors.require(tt.wantErrors, "TEMPFILE", tmp)
if tt.wantTestFile != nil {
tt.wantTestFile(t, tmp)
}
})
}
}
func TestPutToken(t *testing.T) {
t.Parallel()
now := time.Now().Round(1 * time.Second)
tests := []struct {
name string
makeTestFile func(t *testing.T, tmp string)
key oidcclient.SessionCacheKey
token *oidcclient.Token
wantErrors []string
wantTestFile func(t *testing.T, tmp string)
}{
{
name: "fail to create directory",
makeTestFile: func(t *testing.T, tmp string) {
require.NoError(t, ioutil.WriteFile(filepath.Dir(tmp), []byte{}, 0600))
},
wantErrors: []string{
"could not create session cache directory: mkdir TEMPDIR: not a directory",
},
},
{
name: "update to existing entry",
makeTestFile: func(t *testing.T, tmp string) {
validCache := emptySessionCache()
validCache.insert(sessionEntry{
Key: oidcclient.SessionCacheKey{
Issuer: "test-issuer",
ClientID: "test-client-id",
Scopes: []string{"email", "offline_access", "openid", "profile"},
RedirectURI: "http://localhost:0/callback",
},
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
Tokens: oidcclient.Token{
AccessToken: &oidcclient.AccessToken{
Token: "old-access-token",
Type: "Bearer",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
},
IDToken: &oidcclient.IDToken{
Token: "old-id-token",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
},
RefreshToken: &oidcclient.RefreshToken{
Token: "old-refresh-token",
},
},
})
require.NoError(t, os.MkdirAll(filepath.Dir(tmp), 0700))
require.NoError(t, validCache.writeTo(tmp))
},
key: oidcclient.SessionCacheKey{
Issuer: "test-issuer",
ClientID: "test-client-id",
Scopes: []string{"email", "offline_access", "openid", "profile"},
RedirectURI: "http://localhost:0/callback",
},
token: &oidcclient.Token{
AccessToken: &oidcclient.AccessToken{
Token: "new-access-token",
Type: "Bearer",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
},
IDToken: &oidcclient.IDToken{
Token: "new-id-token",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
},
RefreshToken: &oidcclient.RefreshToken{
Token: "new-refresh-token",
},
},
wantTestFile: func(t *testing.T, tmp string) {
cache, err := readSessionCache(tmp)
require.NoError(t, err)
require.Len(t, cache.Sessions, 1)
require.Less(t, time.Since(cache.Sessions[0].LastUsedTimestamp.Time).Nanoseconds(), (5 * time.Second).Nanoseconds())
require.Equal(t, oidcclient.Token{
AccessToken: &oidcclient.AccessToken{
Token: "new-access-token",
Type: "Bearer",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
},
IDToken: &oidcclient.IDToken{
Token: "new-id-token",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
},
RefreshToken: &oidcclient.RefreshToken{
Token: "new-refresh-token",
},
}, cache.Sessions[0].Tokens)
},
},
{
name: "new entry",
makeTestFile: func(t *testing.T, tmp string) {
validCache := emptySessionCache()
validCache.insert(sessionEntry{
Key: oidcclient.SessionCacheKey{
Issuer: "not-the-test-issuer",
ClientID: "not-the-test-client-id",
Scopes: []string{"email", "offline_access", "openid", "profile"},
RedirectURI: "http://localhost:0/callback",
},
CreationTimestamp: metav1.NewTime(now.Add(-2 * time.Hour)),
LastUsedTimestamp: metav1.NewTime(now.Add(-1 * time.Hour)),
Tokens: oidcclient.Token{
AccessToken: &oidcclient.AccessToken{
Token: "old-access-token",
Type: "Bearer",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
},
IDToken: &oidcclient.IDToken{
Token: "old-id-token",
Expiry: metav1.NewTime(now.Add(1 * time.Hour)),
},
RefreshToken: &oidcclient.RefreshToken{
Token: "old-refresh-token",
},
},
})
require.NoError(t, os.MkdirAll(filepath.Dir(tmp), 0700))
require.NoError(t, validCache.writeTo(tmp))
},
key: oidcclient.SessionCacheKey{
Issuer: "test-issuer",
ClientID: "test-client-id",
Scopes: []string{"email", "offline_access", "openid", "profile"},
RedirectURI: "http://localhost:0/callback",
},
token: &oidcclient.Token{
AccessToken: &oidcclient.AccessToken{
Token: "new-access-token",
Type: "Bearer",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
},
IDToken: &oidcclient.IDToken{
Token: "new-id-token",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
},
RefreshToken: &oidcclient.RefreshToken{
Token: "new-refresh-token",
},
},
wantTestFile: func(t *testing.T, tmp string) {
cache, err := readSessionCache(tmp)
require.NoError(t, err)
require.Len(t, cache.Sessions, 2)
require.Less(t, time.Since(cache.Sessions[1].LastUsedTimestamp.Time).Nanoseconds(), (5 * time.Second).Nanoseconds())
require.Equal(t, oidcclient.Token{
AccessToken: &oidcclient.AccessToken{
Token: "new-access-token",
Type: "Bearer",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
},
IDToken: &oidcclient.IDToken{
Token: "new-id-token",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
},
RefreshToken: &oidcclient.RefreshToken{
Token: "new-refresh-token",
},
}, cache.Sessions[1].Tokens)
},
},
{
name: "error writing cache",
makeTestFile: func(t *testing.T, tmp string) {
require.NoError(t, os.MkdirAll(tmp, 0700))
// require.NoError(t, emptySessionCache().writeTo(tmp))
// require.NoError(t, os.Chmod(tmp, 0400))
},
key: oidcclient.SessionCacheKey{
Issuer: "test-issuer",
ClientID: "test-client-id",
Scopes: []string{"email", "offline_access", "openid", "profile"},
RedirectURI: "http://localhost:0/callback",
},
token: &oidcclient.Token{
AccessToken: &oidcclient.AccessToken{
Token: "new-access-token",
Type: "Bearer",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
},
IDToken: &oidcclient.IDToken{
Token: "new-id-token",
Expiry: metav1.NewTime(now.Add(2 * time.Hour).Local()),
},
RefreshToken: &oidcclient.RefreshToken{
Token: "new-refresh-token",
},
},
wantErrors: []string{
"failed to read cache, resetting: could not read session file: read TEMPFILE: is a directory",
"could not write session cache: open TEMPFILE: is a directory",
},
wantTestFile: func(t *testing.T, tmp string) {
// cache, err := readSessionCache(tmp)
// require.NoError(t, err)
// require.Len(t, cache.Sessions, 0)
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
tmp := t.TempDir() + "/sessiondir/sessions.yaml"
if tt.makeTestFile != nil {
tt.makeTestFile(t, tmp)
}
// Initialize a cache with a reporter that collects errors
errors := errorCollector{t: t}
c := New(tmp, errors.collect())
c.PutToken(tt.key, tt.token)
errors.require(tt.wantErrors, "TEMPFILE", tmp, "TEMPDIR", filepath.Dir(tmp))
if tt.wantTestFile != nil {
tt.wantTestFile(t, tmp)
}
})
}
}
type errorCollector struct {
t *testing.T
saw []error
}
func (e *errorCollector) collect() Option {
return WithErrorReporter(func(err error) {
e.saw = append(e.saw, err)
})
}
func (e *errorCollector) require(want []string, subs ...string) {
require.Len(e.t, e.saw, len(want))
for i, w := range want {
for i := 0; i < len(subs); i += 2 {
w = strings.ReplaceAll(w, subs[i], subs[i+1])
}
require.EqualError(e.t, e.saw[i], w)
}
}

View File

@ -0,0 +1 @@
invalid YAML

View File

@ -0,0 +1,24 @@
apiVersion: config.pinniped.dev/v1alpha1
kind: SessionCache
sessions:
- creationTimestamp: "2020-10-20T18:42:07Z"
key:
clientID: test-client-id
issuer: test-issuer
redirect_uri: http://localhost:0/callback
scopes:
- email
- offline_access
- openid
- profile
lastUsedTimestamp: "2020-10-20T18:45:31Z"
tokens:
access:
expiryTimestamp: "2020-10-20T19:46:30Z"
token: test-access-token
type: Bearer
id:
expiryTimestamp: "2020-10-20T19:42:07Z"
token: test-id-token
refresh:
token: test-refresh-token

View File

@ -0,0 +1,3 @@
apiVersion: config.pinniped.dev/v2alpha6
kind: NotASessionCache
sessions: []

View File

@ -1,8 +1,8 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved. // Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
// Package login implements a CLI OIDC login flow. // Package oidcclient implements a CLI OIDC login flow.
package login package oidcclient
import ( import (
"context" "context"
@ -10,11 +10,13 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"sort"
"time" "time"
"github.com/coreos/go-oidc" "github.com/coreos/go-oidc"
"github.com/pkg/browser" "github.com/pkg/browser"
"golang.org/x/oauth2" "golang.org/x/oauth2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/httputil/securityheader" "go.pinniped.dev/internal/httputil/securityheader"
@ -23,12 +25,20 @@ import (
"go.pinniped.dev/internal/oidcclient/state" "go.pinniped.dev/internal/oidcclient/state"
) )
const (
// minIDTokenValidity is the minimum amount of time that a cached ID token must be still be valid to be considered.
// This is non-zero to ensure that most of the time, your ID token won't expire in the middle of a multi-step k8s
// API operation.
minIDTokenValidity = 10 * time.Minute
)
type handlerState struct { type handlerState struct {
// Basic parameters. // Basic parameters.
ctx context.Context ctx context.Context
issuer string issuer string
clientID string clientID string
scopes []string scopes []string
cache SessionCache
// Parameters of the localhost listener. // Parameters of the localhost listener.
listenAddr string listenAddr string
@ -55,13 +65,7 @@ type callbackResult struct {
err error err error
} }
type Token struct { // Option is an optional configuration for Login().
*oauth2.Token
IDToken string `json:"id_token"`
IDTokenExpiry time.Time `json:"id_token_expiry"`
}
// Option is an optional configuration for Run().
type Option func(*handlerState) error type Option func(*handlerState) error
// WithContext specifies a specific context.Context under which to perform the login. If this option is not specified, // WithContext specifies a specific context.Context under which to perform the login. If this option is not specified,
@ -105,13 +109,28 @@ func WithBrowserOpen(openURL func(url string) error) Option {
} }
} }
// Run an OAuth2/OIDC authorization code login using a localhost listener. // WithSessionCache sets the session cache backend for storing and retrieving previously-issued ID tokens and refresh tokens.
func Run(issuer string, clientID string, opts ...Option) (*Token, error) { func WithSessionCache(cache SessionCache) Option {
return func(h *handlerState) error {
h.cache = cache
return nil
}
}
// nopCache is a SessionCache that doesn't actually do anything.
type nopCache struct{}
func (*nopCache) GetToken(SessionCacheKey) *Token { return nil }
func (*nopCache) PutToken(SessionCacheKey, *Token) {}
// Login performs an OAuth2/OIDC authorization code login using a localhost listener.
func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
h := handlerState{ h := handlerState{
issuer: issuer, issuer: issuer,
clientID: clientID, clientID: clientID,
listenAddr: "localhost:0", listenAddr: "localhost:0",
scopes: []string{"offline_access", "openid", "email", "profile"}, scopes: []string{"offline_access", "openid", "email", "profile"},
cache: &nopCache{},
callbackPath: "/callback", callbackPath: "/callback",
ctx: context.Background(), ctx: context.Background(),
callbacks: make(chan callbackResult), callbacks: make(chan callbackResult),
@ -148,6 +167,22 @@ func Run(issuer string, clientID string, opts ...Option) (*Token, error) {
return nil, err return nil, err
} }
// Check the cache for a previous session issued with the same parameters.
sort.Strings(h.scopes)
cacheKey := SessionCacheKey{
Issuer: h.issuer,
ClientID: h.clientID,
Scopes: h.scopes,
RedirectURI: (&url.URL{Scheme: "http", Host: h.listenAddr, Path: h.callbackPath}).String(),
}
// If the ID token is still valid for a bit, return it immediately and skip the rest of the flow.
if cached := h.cache.GetToken(cacheKey); cached != nil &&
cached.IDToken != nil &&
time.Until(cached.IDToken.Expiry.Time) > minIDTokenValidity {
return cached, nil
}
// Perform OIDC discovery. // Perform OIDC discovery.
provider, err := oidc.NewProvider(h.ctx, h.issuer) provider, err := oidc.NewProvider(h.ctx, h.issuer)
if err != nil { if err != nil {
@ -209,6 +244,7 @@ func Run(issuer string, clientID string, opts ...Option) (*Token, error) {
if callback.err != nil { if callback.err != nil {
return nil, fmt.Errorf("error handling callback: %w", callback.err) return nil, fmt.Errorf("error handling callback: %w", callback.err)
} }
h.cache.PutToken(cacheKey, callback.token)
return callback.token, nil return callback.token, nil
} }
} }
@ -262,9 +298,18 @@ func (h *handlerState) handleAuthCodeCallback(w http.ResponseWriter, r *http.Req
} }
h.callbacks <- callbackResult{token: &Token{ h.callbacks <- callbackResult{token: &Token{
Token: oauth2Tok, AccessToken: &AccessToken{
IDToken: idTok, Token: oauth2Tok.AccessToken,
IDTokenExpiry: validated.Expiry, Type: oauth2Tok.TokenType,
Expiry: metav1.NewTime(oauth2Tok.Expiry),
},
RefreshToken: &RefreshToken{
Token: oauth2Tok.RefreshToken,
},
IDToken: &IDToken{
Token: idTok,
Expiry: metav1.NewTime(validated.Expiry),
},
}} }}
_, _ = w.Write([]byte("you have been logged in and may now close this tab")) _, _ = w.Write([]byte("you have been logged in and may now close this tab"))
return nil return nil

View File

@ -1,7 +1,7 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved. // Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
package login package oidcclient
import ( import (
"context" "context"
@ -18,6 +18,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"go.pinniped.dev/internal/httputil/httperr" "go.pinniped.dev/internal/httputil/httperr"
"go.pinniped.dev/internal/mocks/mockkeyset" "go.pinniped.dev/internal/mocks/mockkeyset"
@ -26,18 +27,42 @@ import (
"go.pinniped.dev/internal/oidcclient/state" "go.pinniped.dev/internal/oidcclient/state"
) )
func TestRun(t *testing.T) { // mockSessionCache exists to avoid an import cycle if we generate mocks into another package.
type mockSessionCache struct {
t *testing.T
getReturnsToken *Token
sawGetKeys []SessionCacheKey
sawPutKeys []SessionCacheKey
sawPutTokens []*Token
}
func (m *mockSessionCache) GetToken(key SessionCacheKey) *Token {
m.t.Logf("saw mock session cache GetToken() with client ID %s", key.ClientID)
m.sawGetKeys = append(m.sawGetKeys, key)
return m.getReturnsToken
}
func (m *mockSessionCache) PutToken(key SessionCacheKey, token *Token) {
m.t.Logf("saw mock session cache PutToken() with client ID %s and ID token %s", key.ClientID, token.IDToken.Token)
m.sawPutKeys = append(m.sawPutKeys, key)
m.sawPutTokens = append(m.sawPutTokens, token)
}
func TestLogin(t *testing.T) {
time1 := time.Date(3020, 10, 12, 13, 14, 15, 16, time.UTC) time1 := time.Date(3020, 10, 12, 13, 14, 15, 16, time.UTC)
testToken := Token{ testToken := Token{
Token: &oauth2.Token{ AccessToken: &AccessToken{
AccessToken: "test-access-token", Token: "test-access-token",
RefreshToken: "test-refresh-token", Expiry: metav1.NewTime(time1.Add(1 * time.Minute)),
Expiry: time1.Add(1 * time.Minute), },
RefreshToken: &RefreshToken{
Token: "test-refresh-token",
},
IDToken: &IDToken{
Token: "test-id-token",
Expiry: metav1.NewTime(time1.Add(2 * time.Minute)),
}, },
IDToken: "test-id-token",
IDTokenExpiry: time1.Add(2 * time.Minute),
} }
_ = testToken
// Start a test server that returns 500 errors // Start a test server that returns 500 errors
errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { errorServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -112,6 +137,53 @@ func TestRun(t *testing.T) {
}, },
wantErr: "some error generating PKCE", wantErr: "some error generating PKCE",
}, },
{
name: "session cache hit but token expired",
issuer: "test-issuer",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &Token{
IDToken: &IDToken{
Token: "test-id-token",
Expiry: metav1.NewTime(time.Now()), // less than Now() + minIDTokenValidity
},
}}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: "test-issuer",
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
return WithSessionCache(cache)(h)
}
},
wantErr: `could not perform OIDC discovery for "test-issuer": Get "test-issuer/.well-known/openid-configuration": unsupported protocol scheme ""`,
},
{
name: "session cache hit with valid token",
issuer: "test-issuer",
clientID: "test-client-id",
opt: func(t *testing.T) Option {
return func(h *handlerState) error {
cache := &mockSessionCache{t: t, getReturnsToken: &testToken}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{{
Issuer: "test-issuer",
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}}, cache.sawGetKeys)
require.Empty(t, cache.sawPutTokens)
})
return WithSessionCache(cache)(h)
}
},
wantToken: &testToken,
},
{ {
name: "discovery failure", name: "discovery failure",
opt: func(t *testing.T) Option { opt: func(t *testing.T) Option {
@ -183,6 +255,20 @@ func TestRun(t *testing.T) {
h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil } h.generatePKCE = func() (pkce.Code, error) { return "test-pkce", nil }
h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil } h.generateNonce = func() (nonce.Nonce, error) { return "test-nonce", nil }
cache := &mockSessionCache{t: t, getReturnsToken: nil}
cacheKey := SessionCacheKey{
Issuer: successServer.URL,
ClientID: "test-client-id",
Scopes: []string{"test-scope"},
RedirectURI: "http://localhost:0/callback",
}
t.Cleanup(func() {
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawGetKeys)
require.Equal(t, []SessionCacheKey{cacheKey}, cache.sawPutKeys)
require.Equal(t, []*Token{&testToken}, cache.sawPutTokens)
})
require.NoError(t, WithSessionCache(cache)(h))
h.openURL = func(actualURL string) error { h.openURL = func(actualURL string) error {
parsedActualURL, err := url.Parse(actualURL) parsedActualURL, err := url.Parse(actualURL)
require.NoError(t, err) require.NoError(t, err)
@ -223,7 +309,7 @@ func TestRun(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tok, err := Run(tt.issuer, tt.clientID, tok, err := Login(tt.issuer, tt.clientID,
WithContext(context.Background()), WithContext(context.Background()),
WithListenPort(0), WithListenPort(0),
WithScopes([]string{"test-scope"}), WithScopes([]string{"test-scope"}),
@ -393,7 +479,7 @@ func TestHandleAuthCodeCallback(t *testing.T) {
} }
require.NoError(t, result.err) require.NoError(t, result.err)
require.NotNil(t, result.token) require.NotNil(t, result.token)
require.Equal(t, result.token.IDToken, tt.returnIDTok) require.Equal(t, result.token.IDToken.Token, tt.returnIDTok)
} }
}) })
} }

View File

@ -0,0 +1,62 @@
// Copyright 2020 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package oidcclient
import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
// AccessToken is an OAuth2 access token.
type AccessToken struct {
// Token is the token that authorizes and authenticates the requests.
Token string `json:"token"`
// Type is the type of token.
Type string `json:"type,omitempty"`
// Expiry is the optional expiration time of the access token.
Expiry metav1.Time `json:"expiryTimestamp,omitempty"`
}
// RefreshToken is an OAuth2 refresh token.
type RefreshToken struct {
// Token is a token that's used by the application (as opposed to the user) to refresh the access token if it expires.
Token string `json:"token"`
}
// IDToken is an OpenID Connect ID token.
type IDToken struct {
// Token is an OpenID Connect ID token.
Token string `json:"token"`
// Expiry is the optional expiration time of the ID token.
Expiry metav1.Time `json:"expiryTimestamp,omitempty"`
}
// Token contains the elements of an OIDC session.
type Token struct {
// AccessToken is the token that authorizes and authenticates the requests.
AccessToken *AccessToken `json:"access,omitempty"`
// RefreshToken is a token that's used by the application
// (as opposed to the user) to refresh the access token
// if it expires.
RefreshToken *RefreshToken `json:"refresh,omitempty"`
// IDToken is an OpenID Connect ID token.
IDToken *IDToken `json:"id,omitempty"`
}
// SessionCacheKey contains the data used to select a valid session cache entry.
type SessionCacheKey struct {
Issuer string `json:"issuer"`
ClientID string `json:"clientID"`
Scopes []string `json:"scopes"`
RedirectURI string `json:"redirect_uri"`
}
type SessionCache interface {
GetToken(SessionCacheKey) *Token
PutToken(SessionCacheKey, *Token)
}

View File

@ -146,7 +146,7 @@ func getLoginProvider(t *testing.T) *loginProviderPatterns {
func TestCLILoginOIDC(t *testing.T) { func TestCLILoginOIDC(t *testing.T) {
env := library.IntegrationEnv(t) env := library.IntegrationEnv(t)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
// Find the login CSS selectors for the test issuer, or fail fast. // Find the login CSS selectors for the test issuer, or fail fast.
@ -172,12 +172,16 @@ func TestCLILoginOIDC(t *testing.T) {
t.Logf("building CLI binary") t.Logf("building CLI binary")
pinnipedExe := buildPinnipedCLI(t) pinnipedExe := buildPinnipedCLI(t)
// Make a temp directory to hold the session cache for this test.
sessionCachePath := t.TempDir() + "/sessions.yaml"
// Start the CLI running the "alpha login oidc [...]" command with stdout/stderr connected to pipes. // Start the CLI running the "alpha login oidc [...]" command with stdout/stderr connected to pipes.
t.Logf("starting CLI subprocess") t.Logf("starting CLI subprocess")
cmd := exec.CommandContext(ctx, pinnipedExe, "alpha", "login", "oidc", cmd := exec.CommandContext(ctx, pinnipedExe, "alpha", "login", "oidc",
"--issuer", env.OIDCUpstream.Issuer, "--issuer", env.OIDCUpstream.Issuer,
"--client-id", env.OIDCUpstream.ClientID, "--client-id", env.OIDCUpstream.ClientID,
"--listen-port", strconv.Itoa(env.OIDCUpstream.LocalhostPort), "--listen-port", strconv.Itoa(env.OIDCUpstream.LocalhostPort),
"--session-cache", sessionCachePath,
"--skip-browser", "--skip-browser",
) )
stderr, err := cmd.StderrPipe() stderr, err := cmd.StderrPipe()
@ -305,6 +309,26 @@ func TestCLILoginOIDC(t *testing.T) {
require.Equal(t, env.OIDCUpstream.ClientID, claims["aud"]) require.Equal(t, env.OIDCUpstream.ClientID, claims["aud"])
require.Equal(t, env.OIDCUpstream.Username, claims["email"]) require.Equal(t, env.OIDCUpstream.Username, claims["email"])
require.NotEmpty(t, claims["nonce"]) require.NotEmpty(t, claims["nonce"])
// Run the CLI again with the same session cache and login parameters.
t.Logf("starting second CLI subprocess to test session caching")
secondCtx, secondCancel := context.WithTimeout(ctx, 5*time.Second)
defer secondCancel()
cmdOutput, err := exec.CommandContext(secondCtx, pinnipedExe, "alpha", "login", "oidc",
"--issuer", env.OIDCUpstream.Issuer,
"--client-id", env.OIDCUpstream.ClientID,
"--listen-port", strconv.Itoa(env.OIDCUpstream.LocalhostPort),
"--session-cache", sessionCachePath,
"--skip-browser",
).CombinedOutput()
require.NoError(t, err)
// Expect the CLI to output the same ExecCredential in JSON format.
t.Logf("validating second ExecCredential")
var credOutput2 clientauthenticationv1beta1.ExecCredential
require.NoErrorf(t, json.Unmarshal(cmdOutput, &credOutput2),
"command returned something other than an ExecCredential:\n%s", string(cmdOutput))
require.Equal(t, credOutput, credOutput2)
} }
func waitForVisibleElements(t *testing.T, page *agouti.Page, selectors ...string) { func waitForVisibleElements(t *testing.T, page *agouti.Page, selectors ...string) {