Add a context parameter so we can enforce a timeout for the token exchange.

Signed-off-by: Matt Moyer <moyerm@vmware.com>
This commit is contained in:
Matt Moyer 2020-07-28 09:10:40 -05:00
parent 0ee4f0417d
commit 1a349bb609
3 changed files with 47 additions and 18 deletions

View File

@ -6,10 +6,12 @@ SPDX-License-Identifier: Apache-2.0
package main package main
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"os" "os"
"time"
"k8s.io/client-go/pkg/apis/clientauthentication" "k8s.io/client-go/pkg/apis/clientauthentication"
@ -18,7 +20,7 @@ import (
) )
func main() { func main() {
err := run(os.LookupEnv, client.ExchangeToken, os.Stdout) err := run(os.LookupEnv, client.ExchangeToken, os.Stdout, 30*time.Second)
if err != nil { if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "%s", err.Error()) _, _ = fmt.Fprintf(os.Stderr, "%s", err.Error())
os.Exit(1) os.Exit(1)
@ -26,11 +28,14 @@ func main() {
} }
type envGetter func(string) (string, bool) type envGetter func(string) (string, bool)
type tokenExchanger func(token, caBundle, apiEndpoint string) (*clientauthentication.ExecCredential, error) type tokenExchanger func(ctx context.Context, token, caBundle, apiEndpoint string) (*clientauthentication.ExecCredential, error)
const ErrMissingEnvVar = constable.Error("failed to login: environment variable not set") const ErrMissingEnvVar = constable.Error("failed to login: environment variable not set")
func run(envGetter envGetter, tokenExchanger tokenExchanger, outputWriter io.Writer) error { func run(envGetter envGetter, tokenExchanger tokenExchanger, outputWriter io.Writer, timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
token, varExists := envGetter("PLACEHOLDER_NAME_TOKEN") token, varExists := envGetter("PLACEHOLDER_NAME_TOKEN")
if !varExists { if !varExists {
return envVarNotSetError("PLACEHOLDER_NAME_TOKEN") return envVarNotSetError("PLACEHOLDER_NAME_TOKEN")
@ -46,7 +51,7 @@ func run(envGetter envGetter, tokenExchanger tokenExchanger, outputWriter io.Wri
return envVarNotSetError("PLACEHOLDER_NAME_K8S_API_ENDPOINT") return envVarNotSetError("PLACEHOLDER_NAME_K8S_API_ENDPOINT")
} }
execCredential, err := tokenExchanger(token, caBundle, apiEndpoint) execCredential, err := tokenExchanger(ctx, token, caBundle, apiEndpoint)
if err != nil { if err != nil {
return fmt.Errorf("failed to login: %w", err) return fmt.Errorf("failed to login: %w", err)
} }

View File

@ -7,8 +7,10 @@ package main
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"testing" "testing"
"time"
"github.com/sclevine/spec" "github.com/sclevine/spec"
"github.com/sclevine/spec/report" "github.com/sclevine/spec/report"
@ -49,7 +51,7 @@ func TestRun(t *testing.T) {
"PLACEHOLDER_NAME_K8S_API_ENDPOINT": "a", "PLACEHOLDER_NAME_K8S_API_ENDPOINT": "a",
"PLACEHOLDER_NAME_CA_BUNDLE": "b", "PLACEHOLDER_NAME_CA_BUNDLE": "b",
} }
err := run(envGetter, tokenExchanger, buffer) err := run(envGetter, tokenExchanger, buffer, 30*time.Second)
require.EqualError(t, err, "failed to login: environment variable not set: PLACEHOLDER_NAME_TOKEN") require.EqualError(t, err, "failed to login: environment variable not set: PLACEHOLDER_NAME_TOKEN")
}) })
@ -58,7 +60,7 @@ func TestRun(t *testing.T) {
"PLACEHOLDER_NAME_K8S_API_ENDPOINT": "a", "PLACEHOLDER_NAME_K8S_API_ENDPOINT": "a",
"PLACEHOLDER_NAME_TOKEN": "b", "PLACEHOLDER_NAME_TOKEN": "b",
} }
err := run(envGetter, tokenExchanger, buffer) err := run(envGetter, tokenExchanger, buffer, 30*time.Second)
require.EqualError(t, err, "failed to login: environment variable not set: PLACEHOLDER_NAME_CA_BUNDLE") require.EqualError(t, err, "failed to login: environment variable not set: PLACEHOLDER_NAME_CA_BUNDLE")
}) })
@ -67,27 +69,27 @@ func TestRun(t *testing.T) {
"PLACEHOLDER_NAME_TOKEN": "a", "PLACEHOLDER_NAME_TOKEN": "a",
"PLACEHOLDER_NAME_CA_BUNDLE": "b", "PLACEHOLDER_NAME_CA_BUNDLE": "b",
} }
err := run(envGetter, tokenExchanger, buffer) err := run(envGetter, tokenExchanger, buffer, 30*time.Second)
require.EqualError(t, err, "failed to login: environment variable not set: PLACEHOLDER_NAME_K8S_API_ENDPOINT") require.EqualError(t, err, "failed to login: environment variable not set: PLACEHOLDER_NAME_K8S_API_ENDPOINT")
}) })
}, spec.Parallel()) }, spec.Parallel())
when("the token exchange fails", func() { when("the token exchange fails", func() {
it.Before(func() { it.Before(func() {
tokenExchanger = func(token, caBundle, apiEndpoint string) (*clientauthentication.ExecCredential, error) { tokenExchanger = func(ctx context.Context, token, caBundle, apiEndpoint string) (*clientauthentication.ExecCredential, error) {
return nil, fmt.Errorf("some error") return nil, fmt.Errorf("some error")
} }
}) })
it("returns an error", func() { it("returns an error", func() {
err := run(envGetter, tokenExchanger, buffer) err := run(envGetter, tokenExchanger, buffer, 30*time.Second)
require.EqualError(t, err, "failed to login: some error") require.EqualError(t, err, "failed to login: some error")
}) })
}, spec.Parallel()) }, spec.Parallel())
when("the JSON encoder fails", func() { when("the JSON encoder fails", func() {
it.Before(func() { it.Before(func() {
tokenExchanger = func(token, caBundle, apiEndpoint string) (*clientauthentication.ExecCredential, error) { tokenExchanger = func(ctx context.Context, token, caBundle, apiEndpoint string) (*clientauthentication.ExecCredential, error) {
return &clientauthentication.ExecCredential{ return &clientauthentication.ExecCredential{
Status: &clientauthentication.ExecCredentialStatus{Token: "some token"}, Status: &clientauthentication.ExecCredentialStatus{Token: "some token"},
}, nil }, nil
@ -95,16 +97,36 @@ func TestRun(t *testing.T) {
}) })
it("returns an error", func() { it("returns an error", func() {
err := run(envGetter, tokenExchanger, &errWriter{returnErr: fmt.Errorf("some IO error")}) err := run(envGetter, tokenExchanger, &errWriter{returnErr: fmt.Errorf("some IO error")}, 30*time.Second)
require.EqualError(t, err, "failed to marshal response to stdout: some IO error") require.EqualError(t, err, "failed to marshal response to stdout: some IO error")
}) })
}, spec.Parallel()) }, spec.Parallel())
when("the token exchange times out", func() {
it.Before(func() {
tokenExchanger = func(ctx context.Context, token, caBundle, apiEndpoint string) (*clientauthentication.ExecCredential, error) {
select {
case <-time.After(100 * time.Millisecond):
return &clientauthentication.ExecCredential{
Status: &clientauthentication.ExecCredentialStatus{Token: "some token"},
}, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
})
it("returns an error", func() {
err := run(envGetter, tokenExchanger, buffer, 1*time.Millisecond)
require.EqualError(t, err, "failed to login: context deadline exceeded")
})
}, spec.Parallel())
when("the token exchange succeeds", func() { when("the token exchange succeeds", func() {
var actualToken, actualCaBundle, actualAPIEndpoint string var actualToken, actualCaBundle, actualAPIEndpoint string
it.Before(func() { it.Before(func() {
tokenExchanger = func(token, caBundle, apiEndpoint string) (*clientauthentication.ExecCredential, error) { tokenExchanger = func(ctx context.Context, token, caBundle, apiEndpoint string) (*clientauthentication.ExecCredential, error) {
actualToken, actualCaBundle, actualAPIEndpoint = token, caBundle, apiEndpoint actualToken, actualCaBundle, actualAPIEndpoint = token, caBundle, apiEndpoint
return &clientauthentication.ExecCredential{ return &clientauthentication.ExecCredential{
Status: &clientauthentication.ExecCredentialStatus{Token: "some token"}, Status: &clientauthentication.ExecCredentialStatus{Token: "some token"},
@ -113,7 +135,7 @@ func TestRun(t *testing.T) {
}) })
it("writes the execCredential to the given writer", func() { it("writes the execCredential to the given writer", func() {
err := run(envGetter, tokenExchanger, buffer) err := run(envGetter, tokenExchanger, buffer, 30*time.Second)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, fakeEnv["PLACEHOLDER_NAME_TOKEN"], actualToken) require.Equal(t, fakeEnv["PLACEHOLDER_NAME_TOKEN"], actualToken)
require.Equal(t, fakeEnv["PLACEHOLDER_NAME_CA_BUNDLE"], actualCaBundle) require.Equal(t, fakeEnv["PLACEHOLDER_NAME_CA_BUNDLE"], actualCaBundle)

View File

@ -5,11 +5,13 @@ SPDX-License-Identifier: Apache-2.0
package client package client
import "k8s.io/client-go/pkg/apis/clientauthentication" import (
"context"
func ExchangeToken(token, caBundle, apiEndpoint string) (*clientauthentication.ExecCredential, error) { "k8s.io/client-go/pkg/apis/clientauthentication"
_ = token )
_ = caBundle
_ = apiEndpoint func ExchangeToken(ctx context.Context, token, caBundle, apiEndpoint string) (*clientauthentication.ExecCredential, error) {
_, _, _, _ = ctx, token, caBundle, apiEndpoint
return nil, nil return nil, nil
} }