Add --ca-bundle flag to "pinniped login oidc" command.

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2020-11-16 11:54:13 -06:00
parent e7ecfd3954
commit dd2133458e
No known key found for this signature in database
GPG Key ID: EAE88AD172C5AE2D
5 changed files with 43 additions and 1 deletions

View File

@ -60,7 +60,7 @@ issues:
linters-settings: linters-settings:
funlen: funlen:
lines: 125 lines: 150
statements: 50 statements: 50
goheader: goheader:
template: |- template: |-

View File

@ -4,7 +4,12 @@
package cmd package cmd
import ( import (
"crypto/tls"
"crypto/x509"
"encoding/json" "encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
@ -36,6 +41,7 @@ func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oid
scopes []string scopes []string
skipBrowser bool skipBrowser bool
sessionCachePath string sessionCachePath string
caBundlePaths []string
debugSessionCache bool debugSessionCache bool
) )
cmd.Flags().StringVar(&issuer, "issuer", "", "OpenID Connect issuer URL.") cmd.Flags().StringVar(&issuer, "issuer", "", "OpenID Connect issuer URL.")
@ -44,6 +50,7 @@ func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oid
cmd.Flags().StringSliceVar(&scopes, "scopes", []string{"offline_access", "openid", "email", "profile"}, "OIDC scopes to request during login.") cmd.Flags().StringSliceVar(&scopes, "scopes", []string{"offline_access", "openid", "email", "profile"}, "OIDC scopes to request during login.")
cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL).") cmd.Flags().BoolVar(&skipBrowser, "skip-browser", false, "Skip opening the browser (just print the URL).")
cmd.Flags().StringVar(&sessionCachePath, "session-cache", filepath.Join(mustGetConfigDir(), "sessions.yaml"), "Path to session cache file.") cmd.Flags().StringVar(&sessionCachePath, "session-cache", filepath.Join(mustGetConfigDir(), "sessions.yaml"), "Path to session cache file.")
cmd.Flags().StringSliceVar(&caBundlePaths, "ca-bundle", nil, "Path to TLS certificate authority bundle (PEM format, optional, can be repeated).")
cmd.Flags().BoolVar(&debugSessionCache, "debug-session-cache", false, "Print debug logs related to the session cache.") cmd.Flags().BoolVar(&debugSessionCache, "debug-session-cache", false, "Print debug logs related to the session cache.")
mustMarkHidden(&cmd, "debug-session-cache") mustMarkHidden(&cmd, "debug-session-cache")
mustMarkRequired(&cmd, "issuer", "client-id") mustMarkRequired(&cmd, "issuer", "client-id")
@ -80,6 +87,27 @@ func oidcLoginCommand(loginFunc func(issuer string, clientID string, opts ...oid
})) }))
} }
if len(caBundlePaths) > 0 {
pool := x509.NewCertPool()
for _, p := range caBundlePaths {
pem, err := ioutil.ReadFile(p)
if err != nil {
return fmt.Errorf("could not read --ca-bundle: %w", err)
}
pool.AppendCertsFromPEM(pem)
}
tlsConfig := tls.Config{
RootCAs: pool,
MinVersion: tls.VersionTLS12,
}
opts = append(opts, oidcclient.WithClient(&http.Client{
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tlsConfig,
},
}))
}
tok, err := loginFunc(issuer, clientID, opts...) tok, err := loginFunc(issuer, clientID, opts...)
if err != nil { if err != nil {
return err return err

View File

@ -40,6 +40,7 @@ func TestLoginOIDCCommand(t *testing.T) {
oidc --issuer ISSUER --client-id CLIENT_ID [flags] oidc --issuer ISSUER --client-id CLIENT_ID [flags]
Flags: Flags:
--ca-bundle strings Path to TLS certificate authority bundle (PEM format, optional, can be repeated).
--client-id string OpenID Connect client ID. --client-id string OpenID Connect client ID.
-h, --help help for oidc -h, --help help for oidc
--issuer string OpenID Connect issuer URL. --issuer string OpenID Connect issuer URL.

View File

@ -44,6 +44,8 @@ type handlerState struct {
scopes []string scopes []string
cache SessionCache cache SessionCache
httpClient *http.Client
// Parameters of the localhost listener. // Parameters of the localhost listener.
listenAddr string listenAddr string
callbackPath string callbackPath string
@ -122,6 +124,14 @@ func WithSessionCache(cache SessionCache) Option {
} }
} }
// WithClient sets the HTTP client used to make CLI-to-provider requests.
func WithClient(httpClient *http.Client) Option {
return func(h *handlerState) error {
h.httpClient = httpClient
return nil
}
}
// nopCache is a SessionCache that doesn't actually do anything. // nopCache is a SessionCache that doesn't actually do anything.
type nopCache struct{} type nopCache struct{}
@ -144,6 +154,7 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
callbackPath: "/callback", callbackPath: "/callback",
ctx: context.Background(), ctx: context.Background(),
callbacks: make(chan callbackResult), callbacks: make(chan callbackResult),
httpClient: http.DefaultClient,
// Default implementations of external dependencies (to be mocked in tests). // Default implementations of external dependencies (to be mocked in tests).
generateState: state.Generate, generateState: state.Generate,
@ -163,6 +174,7 @@ func Login(issuer string, clientID string, opts ...Option) (*Token, error) {
// Always set a long, but non-infinite timeout for this operation. // Always set a long, but non-infinite timeout for this operation.
ctx, cancel := context.WithTimeout(h.ctx, 10*time.Minute) ctx, cancel := context.WithTimeout(h.ctx, 10*time.Minute)
defer cancel() defer cancel()
ctx = oidc.ClientContext(ctx, h.httpClient)
h.ctx = ctx h.ctx = ctx
// Initialize login parameters. // Initialize login parameters.

View File

@ -416,6 +416,7 @@ func TestLogin(t *testing.T) {
require.Equal(t, []*Token{&testToken}, cache.sawPutTokens) require.Equal(t, []*Token{&testToken}, cache.sawPutTokens)
}) })
require.NoError(t, WithSessionCache(cache)(h)) require.NoError(t, WithSessionCache(cache)(h))
require.NoError(t, WithClient(&http.Client{Timeout: 10 * time.Second})(h))
h.openURL = func(actualURL string) error { h.openURL = func(actualURL string) error {
parsedActualURL, err := url.Parse(actualURL) parsedActualURL, err := url.Parse(actualURL)