diff --git a/.golangci.yaml b/.golangci.yaml index a9970589..9477ae33 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -51,11 +51,12 @@ linters: issues: exclude-rules: - # exclude tests from function length and global linting to encourage table-based tests. + # exclude tests from some rules for things that are useful in a testing context. - path: _test\.go linters: - funlen - gochecknoglobals + - goerr113 linters-settings: funlen: @@ -66,4 +67,4 @@ linters-settings: Copyright 2020 VMware, Inc. SPDX-License-Identifier: Apache-2.0 goimports: - local-prefixes: github.com/suzerain-io \ No newline at end of file + local-prefixes: github.com/suzerain-io diff --git a/cmd/placeholder-name/app/app.go b/cmd/placeholder-name/app/app.go index ff0b0923..30a56fec 100644 --- a/cmd/placeholder-name/app/app.go +++ b/cmd/placeholder-name/app/app.go @@ -7,34 +7,50 @@ SPDX-License-Identifier: Apache-2.0 package app import ( + "context" + "crypto/tls" + "crypto/x509/pkix" + "errors" + "fmt" "io" "log" + "net" "net/http" + "time" "github.com/spf13/cobra" + "golang.org/x/sync/errgroup" + "github.com/suzerain-io/placeholder-name/internal/certauthority" "github.com/suzerain-io/placeholder-name/pkg/handlers" ) +// shutdownGracePeriod controls how long active connections are allowed to continue at shutdown. +const shutdownGracePeriod = 5 * time.Second + // App is an object that represents the placeholder-name application. type App struct { cmd *cobra.Command + // listen address for healthz serve + healthAddr string + + // listen address for main serve + mainAddr string + // runFunc runs the actual program, after the parsing of flags has been done. // // It is mostly a field for the sake of testing. - runFunc func(configPath string) + runFunc func(ctx context.Context, configPath string) error } // New constructs a new App with command line args, stdout and stderr. func New(args []string, stdout, stderr io.Writer) *App { a := &App{ - runFunc: func(configPath string) { - addr := ":8080" - log.Printf("Starting server on %v", addr) - log.Fatal(http.ListenAndServe(addr, handlers.New())) - }, + healthAddr: ":8080", + mainAddr: ":8443", } + a.runFunc = a.serve var configPath string cmd := &cobra.Command{ @@ -42,8 +58,8 @@ func New(args []string, stdout, stderr io.Writer) *App { Long: `placeholder-name provides a generic API for mapping an external credential from somewhere to an internal credential to be used for authenticating to the Kubernetes API.`, - Run: func(cmd *cobra.Command, args []string) { - a.runFunc(configPath) + RunE: func(cmd *cobra.Command, args []string) error { + return a.runFunc(context.Background(), configPath) }, Args: cobra.NoArgs, } @@ -68,3 +84,76 @@ authenticating to the Kubernetes API.`, func (a *App) Run() error { return a.cmd.Execute() } + +func (a *App) serve(ctx context.Context, configPath string) error { + ca, err := certauthority.New(pkix.Name{CommonName: "Placeholder CA"}) + if err != nil { + return fmt.Errorf("could not initialize CA: %w", err) + } + caBundle, err := ca.Bundle() + if err != nil { + return fmt.Errorf("could not read CA bundle: %w", err) + } + log.Printf("initialized CA bundle:\n%s", string(caBundle)) + + cert, err := ca.Issue( + pkix.Name{CommonName: "Placeholder Server"}, + []string{"placeholder-serve"}, + 24*365*time.Hour, + ) + if err != nil { + return fmt.Errorf("could not issue serving certificate: %w", err) + } + + // Start an errgroup to manage the lifetimes of the various listener goroutines. + eg, ctx := errgroup.WithContext(ctx) + + // Start healthz listener + eg.Go(func() error { + log.Printf("Starting healthz serve on %v", a.healthAddr) + server := http.Server{ + BaseContext: func(_ net.Listener) context.Context { return ctx }, + Addr: a.healthAddr, + Handler: handlers.New(), + } + return runGracefully(ctx, &server, eg) + }) + + // Start main service listener + eg.Go(func() error { + log.Printf("Starting main serve on %v", a.mainAddr) + server := http.Server{ + BaseContext: func(_ net.Listener) context.Context { return ctx }, + Addr: a.mainAddr, + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{*cert}, + }, + Handler: http.HandlerFunc(exampleHandler), + } + return runGracefully(ctx, &server, eg) + }) + + if err := eg.Wait(); !errors.Is(err, http.ErrServerClosed) { + return err + } + return nil +} + +// exampleHandler is a stub to be replaced with our real server logic. +func exampleHandler(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("Hello world")) +} + +// runGracefully runs an http.Server with graceful shutdown. +func runGracefully(ctx context.Context, srv *http.Server, eg *errgroup.Group) error { + // Start the listener in a child goroutine. + eg.Go(srv.ListenAndServe) + + // If/when the context is canceled or times out, initiate shutting down the serve. + <-ctx.Done() + + shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGracePeriod) + defer cancel() + return srv.Shutdown(shutdownCtx) +} diff --git a/cmd/placeholder-name/app/app_test.go b/cmd/placeholder-name/app/app_test.go index 8ef82c51..220d15fd 100644 --- a/cmd/placeholder-name/app/app_test.go +++ b/cmd/placeholder-name/app/app_test.go @@ -7,9 +7,11 @@ package app import ( "bytes" + "context" "testing" + "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const knownGoodUsage = `Usage: @@ -54,17 +56,18 @@ func TestCommand(t *testing.T) { }, }, } - for _, theTest := range tests { - test := theTest // please the linter :'( + for _, test := range tests { + test := test t.Run(test.name, func(t *testing.T) { - expect := assert.New(t) + expect := require.New(t) stdout := bytes.NewBuffer([]byte{}) stderr := bytes.NewBuffer([]byte{}) configPaths := make([]string, 0, 1) - runFunc := func(configPath string) { + runFunc := func(ctx context.Context, configPath string) error { configPaths = append(configPaths, configPath) + return nil } a := New(test.args, stdout, stderr) @@ -72,9 +75,8 @@ func TestCommand(t *testing.T) { err := a.Run() if test.wantConfigPath != "" { - if expect.Equal(1, len(configPaths)) { - expect.Equal(test.wantConfigPath, configPaths[0]) - } + expect.Equal(1, len(configPaths)) + expect.Equal(test.wantConfigPath, configPaths[0]) } else { expect.Error(err) expect.Contains(stdout.String(), knownGoodUsage) @@ -82,3 +84,33 @@ func TestCommand(t *testing.T) { }) } } + +func TestServeApp(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + cancel() + + a := App{ + healthAddr: "127.0.0.1:0", + mainAddr: "127.0.0.1:8443", + } + err := a.serve(ctx, "some/path/to/config.yaml") + require.NoError(t, err) + }) + + t.Run("failure", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + a := App{ + healthAddr: "127.0.0.1:8081", + mainAddr: "127.0.0.1:8081", + } + err := a.serve(ctx, "some/path/to/config.yaml") + require.EqualError(t, err, "listen tcp 127.0.0.1:8081: bind: address already in use") + }) +} diff --git a/go.mod b/go.mod index a82db4e1..9fecb119 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/golangci/golangci-lint v1.28.1 github.com/spf13/cobra v1.0.0 github.com/stretchr/testify v1.6.1 + golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208 golang.org/x/tools v0.0.0-20200707134715-9e0a013e855f // indirect k8s.io/api v0.18.5 k8s.io/apimachinery v0.18.5 diff --git a/internal/certauthority/certauthority.go b/internal/certauthority/certauthority.go new file mode 100644 index 00000000..8026d8e6 --- /dev/null +++ b/internal/certauthority/certauthority.go @@ -0,0 +1,172 @@ +/* +Copyright 2020 VMware, Inc. +SPDX-License-Identifier: Apache-2.0 +*/ + +// Package certauthority implements a simple x509 certificate authority suitable for use in an aggregated API service. +package certauthority + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io" + "math/big" + "time" +) + +// CA holds the state for a simple x509 certificate authority suitable for use in an aggregated API service. +type CA struct { + // secure random number generators for various steps (usually crypto/rand.Reader, but broken out here for tests). + serialRNG io.Reader + keygenRNG io.Reader + signingRNG io.Reader + + // clock tells the current time (usually time.Now(), but broken out here for tests). + clock func() time.Time + + // signer is the private key for the current CA. + signer crypto.Signer + + // caCert is the DER-encoded certificate for the current CA. + caCertBytes []byte +} + +// Option to pass when calling New. +type Option func(*CA) error + +func New(subject pkix.Name, opts ...Option) (*CA, error) { + // Initialize the result by starting with some defaults and applying any provided options. + ca := CA{ + serialRNG: rand.Reader, + keygenRNG: rand.Reader, + signingRNG: rand.Reader, + clock: time.Now, + } + for _, opt := range opts { + if err := opt(&ca); err != nil { + return nil, err + } + } + + // Generate a random serial for the CA + serialNumber, err := randomSerial(ca.serialRNG) + if err != nil { + return nil, fmt.Errorf("could not generate CA serial: %w", err) + } + + // Generate a new P256 keypair. + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), ca.keygenRNG) + if err != nil { + return nil, fmt.Errorf("could not generate CA private key: %w", err) + } + ca.signer = privateKey + + // Make a CA certificate valid for 100 years and backdated by one minute. + now := ca.clock() + notBefore := now.Add(-1 * time.Minute) + notAfter := now.Add(24 * time.Hour * 365 * 100) + + // Create CA cert template + caTemplate := x509.Certificate{ + SerialNumber: serialNumber, + Subject: subject, + NotBefore: notBefore, + NotAfter: notAfter, + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + // Self-sign the CA to get the DER certificate. + caCertBytes, err := x509.CreateCertificate(ca.signingRNG, &caTemplate, &caTemplate, &privateKey.PublicKey, privateKey) + if err != nil { + return nil, fmt.Errorf("could not issue CA certificate: %w", err) + } + ca.caCertBytes = caCertBytes + return &ca, nil +} + +// WriteBundle writes the current CA signing bundle in concatenated PEM format. +func (c *CA) WriteBundle(out io.Writer) error { + if err := pem.Encode(out, &pem.Block{Type: "CERTIFICATE", Bytes: c.caCertBytes}); err != nil { + return fmt.Errorf("could not encode CA certificate to PEM: %w", err) + } + return nil +} + +// Bundle returns the current CA signing bundle in concatenated PEM format. +func (c *CA) Bundle() ([]byte, error) { + var out bytes.Buffer + err := c.WriteBundle(&out) + return out.Bytes(), err +} + +// Issue a new server certificate for the given identity and duration. +func (c *CA) Issue(subject pkix.Name, dnsNames []string, ttl time.Duration) (*tls.Certificate, error) { + // Choose a random 128 bit serial number. + serialNumber, err := randomSerial(c.serialRNG) + if err != nil { + return nil, fmt.Errorf("could not generate serial number for certificate: %w", err) + } + + // Generate a new P256 keypair. + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), c.keygenRNG) + if err != nil { + return nil, fmt.Errorf("could not generate private key: %w", err) + } + + // Make a CA caCert valid for the requested TTL and backdated by one minute. + now := c.clock() + notBefore := now.Add(-1 * time.Minute) + notAfter := now.Add(ttl) + + // Parse the DER encoded certificate to get an x509.Certificate. + caCert, err := x509.ParseCertificate(c.caCertBytes) + if err != nil { + return nil, fmt.Errorf("could not parse CA certificate: %w", err) + } + + // Sign a cert, getting back the DER-encoded certificate bytes. + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: subject, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + IsCA: false, + DNSNames: dnsNames, + } + certBytes, err := x509.CreateCertificate(rand.Reader, &template, caCert, &privateKey.PublicKey, c.signer) + if err != nil { + return nil, fmt.Errorf("could not sign certificate: %w", err) + } + + // Parse the DER encoded certificate back out into an *x509.Certificate. + newCert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, fmt.Errorf("could not parse certificate: %w", err) + } + + // Return the new certificate. + return &tls.Certificate{ + Certificate: [][]byte{certBytes}, + Leaf: newCert, + PrivateKey: privateKey, + }, nil +} + +// randomSerial generates a random 128 bit serial number. +func randomSerial(rng io.Reader) (*big.Int, error) { + return rand.Int(rng, new(big.Int).Lsh(big.NewInt(1), 128)) +} diff --git a/internal/certauthority/certauthority_test.go b/internal/certauthority/certauthority_test.go new file mode 100644 index 00000000..44d99303 --- /dev/null +++ b/internal/certauthority/certauthority_test.go @@ -0,0 +1,220 @@ +/* +Copyright 2020 VMware, Inc. +SPDX-License-Identifier: Apache-2.0 +*/ + +package certauthority + +import ( + "bytes" + "crypto" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "io" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestNew(t *testing.T) { + now := time.Date(2020, 7, 10, 12, 41, 12, 1234, time.UTC) + + tests := []struct { + name string + opts []Option + wantErr string + }{ + { + name: "error option", + opts: []Option{func(ca *CA) error { + return fmt.Errorf("some error") + }}, + wantErr: "some error", + }, + { + name: "failed to generate CA serial", + opts: []Option{func(ca *CA) error { + ca.serialRNG = strings.NewReader("") + ca.keygenRNG = strings.NewReader("") + ca.signingRNG = strings.NewReader("") + return nil + }}, + wantErr: "could not generate CA serial: EOF", + }, + { + name: "failed to generate CA key", + opts: []Option{func(ca *CA) error { + ca.serialRNG = strings.NewReader(strings.Repeat("x", 64)) + ca.keygenRNG = strings.NewReader("") + return nil + }}, + wantErr: "could not generate CA private key: EOF", + }, + { + name: "failed to self-sign", + opts: []Option{func(ca *CA) error { + ca.serialRNG = strings.NewReader(strings.Repeat("x", 64)) + ca.keygenRNG = strings.NewReader(strings.Repeat("y", 64)) + ca.signingRNG = strings.NewReader("") + return nil + }}, + wantErr: "could not issue CA certificate: EOF", + }, + { + name: "success", + opts: []Option{func(ca *CA) error { + ca.serialRNG = strings.NewReader(strings.Repeat("x", 64)) + ca.keygenRNG = strings.NewReader(strings.Repeat("y", 64)) + ca.signingRNG = strings.NewReader(strings.Repeat("z", 64)) + ca.clock = func() time.Time { return now } + return nil + }}, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + got, err := New(pkix.Name{CommonName: "Test CA"}, tt.opts...) + if tt.wantErr != "" { + require.EqualError(t, err, tt.wantErr) + require.Nil(t, got) + return + } + require.NoError(t, err) + require.NotNil(t, got) + + // Make sure the CA certificate looks roughly like what we expect. + caCert, err := x509.ParseCertificate(got.caCertBytes) + require.NoError(t, err) + require.Equal(t, "Test CA", caCert.Subject.CommonName) + require.Equal(t, now.Add(100*365*24*time.Hour).Unix(), caCert.NotAfter.Unix()) + require.Equal(t, now.Add(-1*time.Minute).Unix(), caCert.NotBefore.Unix()) + }) + } +} + +type errWriter struct { + err error +} + +func (e *errWriter) Write(p []byte) (n int, err error) { return 0, e.err } + +func TestWriteBundle(t *testing.T) { + t.Run("error", func(t *testing.T) { + ca := CA{} + out := errWriter{fmt.Errorf("some error")} + require.EqualError(t, ca.WriteBundle(&out), "could not encode CA certificate to PEM: some error") + }) + + t.Run("empty", func(t *testing.T) { + ca := CA{} + var out bytes.Buffer + require.NoError(t, ca.WriteBundle(&out)) + require.Equal(t, "-----BEGIN CERTIFICATE-----\n-----END CERTIFICATE-----\n", out.String()) + }) + + t.Run("success", func(t *testing.T) { + ca := CA{caCertBytes: []byte{1, 2, 3, 4, 5, 6, 7, 8}} + var out bytes.Buffer + require.NoError(t, ca.WriteBundle(&out)) + require.Equal(t, "-----BEGIN CERTIFICATE-----\nAQIDBAUGBwg=\n-----END CERTIFICATE-----\n", out.String()) + }) +} + +func TestBundle(t *testing.T) { + t.Run("success", func(t *testing.T) { + ca := CA{caCertBytes: []byte{1, 2, 3, 4, 5, 6, 7, 8}} + got, err := ca.Bundle() + require.NoError(t, err) + require.Equal(t, "-----BEGIN CERTIFICATE-----\nAQIDBAUGBwg=\n-----END CERTIFICATE-----\n", string(got)) + }) +} + +type errSigner struct { + pubkey crypto.PublicKey + err error +} + +func (e *errSigner) Public() crypto.PublicKey { return e.pubkey } + +func (e *errSigner) Sign(_ io.Reader, _ []byte, _ crypto.SignerOpts) ([]byte, error) { + return nil, e.err +} + +func TestIssue(t *testing.T) { + now := time.Date(2020, 7, 10, 12, 41, 12, 1234, time.UTC) + + realCA, err := New(pkix.Name{CommonName: "Test CA"}) + require.NoError(t, err) + + tests := []struct { + name string + ca CA + wantErr string + }{ + { + name: "failed to generate serial", + ca: CA{ + serialRNG: strings.NewReader(""), + }, + wantErr: "could not generate serial number for certificate: EOF", + }, + { + name: "failed to generate keypair", + ca: CA{ + serialRNG: strings.NewReader(strings.Repeat("x", 64)), + keygenRNG: strings.NewReader(""), + }, + wantErr: "could not generate private key: EOF", + }, + { + name: "invalid CA certificate", + ca: CA{ + serialRNG: strings.NewReader(strings.Repeat("x", 64)), + keygenRNG: strings.NewReader(strings.Repeat("x", 64)), + clock: func() time.Time { return now }, + }, + wantErr: "could not parse CA certificate: asn1: syntax error: sequence truncated", + }, + { + name: "signing error", + ca: CA{ + serialRNG: strings.NewReader(strings.Repeat("x", 64)), + keygenRNG: strings.NewReader(strings.Repeat("x", 64)), + clock: func() time.Time { return now }, + caCertBytes: realCA.caCertBytes, + signer: &errSigner{ + pubkey: realCA.signer.Public(), + err: fmt.Errorf("some signer error"), + }, + }, + wantErr: "could not sign certificate: some signer error", + }, + { + name: "success", + ca: CA{ + serialRNG: strings.NewReader(strings.Repeat("x", 64)), + keygenRNG: strings.NewReader(strings.Repeat("x", 64)), + clock: func() time.Time { return now }, + caCertBytes: realCA.caCertBytes, + signer: realCA.signer, + }, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + got, err := tt.ca.Issue(pkix.Name{CommonName: "Test Server"}, []string{"example.com"}, 10*time.Minute) + if tt.wantErr != "" { + require.EqualError(t, err, tt.wantErr) + require.Nil(t, got) + return + } + require.NoError(t, err) + require.NotNil(t, got) + }) + } +}