// Copyright 2020-2022 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

package certauthority

import (
	"crypto"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"net"
	"os"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/require"

	"go.pinniped.dev/internal/testutil"
)

// loadFromFiles reads a PEM certificate and a PEM private key from disk and
// hands their contents to Load. A failure to read either file fails the test
// immediately; the result of Load (including its error) is returned unchanged
// so callers can assert on it.
func loadFromFiles(t *testing.T, certPath string, keyPath string) (*CA, error) {
	t.Helper()

	readFile := func(path string) string {
		t.Helper()
		contents, err := os.ReadFile(path)
		require.NoError(t, err)
		return string(contents)
	}

	// Arguments are evaluated left to right, so the cert is read before the key.
	return Load(readFile(certPath), readFile(keyPath))
}

// TestLoad exercises Load against PEM fixtures under ./testdata, covering an
// empty key, a non-PEM key, a key that does not match the certificate, a
// bundle holding more than one certificate, and a valid pair.
func TestLoad(t *testing.T) {
	const (
		validCert = "./testdata/test.crt"
		validKey  = "./testdata/test.key"
	)

	tests := []struct {
		name     string
		certPath string
		keyPath  string
		wantErr  string
	}{
		{
			name:     "empty key",
			certPath: validCert,
			keyPath:  "./testdata/empty",
			wantErr:  "could not load CA: tls: failed to find any PEM data in key input",
		},
		{
			name:     "invalid key",
			certPath: validCert,
			keyPath:  "./testdata/invalid",
			wantErr:  "could not load CA: tls: failed to find any PEM data in key input",
		},
		{
			name:     "mismatched cert and key",
			certPath: validCert,
			keyPath:  "./testdata/test2.key",
			wantErr:  "could not load CA: tls: private key does not match public key",
		},
		{
			name:     "multiple certs",
			certPath: "./testdata/multiple.crt",
			keyPath:  validKey,
			wantErr:  "invalid CA certificate: expected a single certificate, found 2 certificates",
		},
		{
			name:     "success",
			certPath: validCert,
			keyPath:  validKey,
		},
	}
	for _, tc := range tests {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			ca, err := loadFromFiles(t, tc.certPath, tc.keyPath)
			if tc.wantErr == "" {
				require.NoError(t, err)
				require.NotEmpty(t, ca.caCertBytes)
				require.NotNil(t, ca.signer)
				// privateKey is only populated for CAs created by New(), never by Load().
				require.Nil(t, ca.privateKey)
				return
			}
			require.EqualError(t, err, tc.wantErr)
		})
	}
}

// TestNew checks that New returns a CA whose certificate carries the requested
// common name, has a NotBefore about five minutes before the call and a
// NotAfter about one TTL after it, and that the CA keeps its private key.
func TestNew(t *testing.T) {
	const (
		ttl       = time.Minute
		tolerance = 10 * time.Second
	)
	start := time.Now()

	ca, err := New("Test CA", ttl)
	require.NoError(t, err)
	require.NotNil(t, ca)

	// Sanity-check the generated CA certificate.
	cert, err := x509.ParseCertificate(ca.caCertBytes)
	require.NoError(t, err)
	require.Equal(t, "Test CA", cert.Subject.CommonName)
	require.WithinDuration(t, start.Add(-5*time.Minute), cert.NotBefore, tolerance)
	require.WithinDuration(t, start.Add(ttl), cert.NotAfter, tolerance)

	require.NotNil(t, ca.privateKey)
}

// TestNewInternal drives newInternal with size-limited, deterministic RNG
// readers and a fixed clock. Each failure path (serial generation, key
// generation, self-signing) is triggered by starving the corresponding reader,
// and the success case checks the resulting certificate's subject and validity.
func TestNewInternal(t *testing.T) {
	now := time.Date(2020, 7, 10, 12, 41, 12, 1234, time.UTC)
	fixedClock := func() time.Time { return now }

	// Each helper returns a fresh reader, since readers are consumed by use:
	// filled yields 64 copies of ch, empty is already at EOF.
	filled := func(ch string) *strings.Reader { return strings.NewReader(strings.Repeat(ch, 64)) }
	empty := func() *strings.Reader { return strings.NewReader("") }

	tests := []struct {
		name           string
		ttl            time.Duration
		env            env
		wantErr        string
		wantCommonName string
		wantNotBefore  time.Time
		wantNotAfter   time.Time
	}{
		{
			name: "failed to generate CA serial",
			env: env{
				serialRNG:  empty(),
				keygenRNG:  empty(),
				signingRNG: empty(),
			},
			wantErr: "could not generate CA serial: EOF",
		},
		{
			name: "failed to generate CA key",
			env: env{
				serialRNG:  filled("x"),
				keygenRNG:  empty(),
				signingRNG: empty(),
			},
			wantErr: "could not generate CA private key: EOF",
		},
		{
			name: "failed to self-sign",
			env: env{
				serialRNG:  filled("x"),
				keygenRNG:  filled("y"),
				signingRNG: empty(),
				clock:      fixedClock,
			},
			wantErr: "could not issue CA certificate: EOF",
		},
		{
			name: "success",
			ttl:  time.Minute,
			env: env{
				serialRNG:  filled("x"),
				keygenRNG:  filled("y"),
				signingRNG: filled("z"),
				clock:      fixedClock,
			},
			wantCommonName: "Test CA",
			wantNotBefore:  now.Add(-5 * time.Minute),
			wantNotAfter:   now.Add(time.Minute),
		},
	}
	for _, tc := range tests {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			ca, err := newInternal("Test CA", tc.ttl, tc.env)
			if tc.wantErr != "" {
				require.EqualError(t, err, tc.wantErr)
				require.Nil(t, ca)
				return
			}
			require.NoError(t, err)
			require.NotNil(t, ca)

			// Sanity-check the generated CA certificate. Validity bounds are
			// compared at whole-second granularity because certificate times are
			// encoded without sub-second precision.
			cert, err := x509.ParseCertificate(ca.caCertBytes)
			require.NoError(t, err)
			require.Equal(t, tc.wantCommonName, cert.Subject.CommonName)
			require.Equal(t, tc.wantNotAfter.Unix(), cert.NotAfter.Unix())
			require.Equal(t, tc.wantNotBefore.Unix(), cert.NotBefore.Unix())
		})
	}
}

// TestBundle checks that Bundle wraps the raw CA certificate bytes in a single
// PEM CERTIFICATE block.
func TestBundle(t *testing.T) {
	const want = "-----BEGIN CERTIFICATE-----\nAQIDBAUGBwg=\n-----END CERTIFICATE-----\n"

	ca := CA{caCertBytes: []byte{1, 2, 3, 4, 5, 6, 7, 8}}
	got := ca.Bundle()
	require.Equal(t, want, string(got))
}

// TestPrivateKeyToPEM checks three things:
//   - a CA from New exports its private key as an EC PEM block;
//   - that key pairs with the CA's certificate bundle;
//   - a CA reconstructed through Load refuses to export a key.
func TestPrivateKeyToPEM(t *testing.T) {
	ca, err := New("Test CA", time.Hour)
	require.NoError(t, err)

	keyPEM, err := ca.PrivateKeyToPEM()
	require.NoError(t, err)
	// NOTE(review): the BEGIN marker is split across two literals, presumably so
	// secret scanners do not flag this file. Keep it split.
	require.Regexp(t, "(?s)-----BEGIN EC "+"PRIVATE KEY-----\n.*\n-----END EC PRIVATE KEY-----", string(keyPEM))

	// The exported private key must match the public key in the CA certificate.
	certPEM := ca.Bundle()
	_, err = tls.X509KeyPair(certPEM, keyPEM)
	require.NoError(t, err)

	// A CA built by Load only holds a signer, so there is no key to export.
	reloaded, err := Load(string(certPEM), string(keyPEM))
	require.NoError(t, err)
	_, exportErr := reloaded.PrivateKeyToPEM()
	require.EqualError(t, exportErr, "no private key data (did you try to use this after Load?)")
}

// TestPool checks that the certificate pool built from a fresh CA contains
// exactly one subject.
func TestPool(t *testing.T) {
	ca, err := New("test", time.Hour)
	require.NoError(t, err)

	// NOTE(review): CertPool.Subjects is deprecated for system pools only; this
	// pool is built from the CA, so using it here should be fine.
	subjects := ca.Pool().Subjects()
	require.Len(t, subjects, 1)
}

// errSigner is a crypto.Signer test double. It reports a caller-chosen public
// key, but every Sign call fails with the configured error.
type errSigner struct {
	pubkey crypto.PublicKey
	err    error
}

// Compile-time check that errSigner satisfies crypto.Signer.
var _ crypto.Signer = (*errSigner)(nil)

// Public returns the configured public key.
func (s *errSigner) Public() crypto.PublicKey {
	return s.pubkey
}

// Sign ignores all of its inputs and always returns a nil signature with the
// configured error.
func (s *errSigner) Sign(io.Reader, []byte, crypto.SignerOpts) ([]byte, error) {
	return nil, s.err
}

// TestIssue drives IssueServerCert and IssueClientCert through each failure
// point by starving the RNG readers, supplying an unparseable CA certificate,
// injecting a failing signer, or injecting a failing certificate parser. It
// also covers a successful issuance. Both issuers are called for every case
// and must agree on the outcome.
func TestIssue(t *testing.T) {
	// Each call to issue a cert consumes 64 bytes from the reader, and every
	// case below issues two certs (one server, one client).
	const numRandBytes = 64 * 2

	now := time.Date(2020, 7, 10, 12, 41, 12, 1234, time.UTC)
	fixedClock := func() time.Time { return now }

	// Each helper returns a fresh reader, since readers are consumed by use.
	plenty := func() *strings.Reader { return strings.NewReader(strings.Repeat("x", numRandBytes)) }
	drained := func() *strings.Reader { return strings.NewReader("") }

	realCA, err := loadFromFiles(t, "./testdata/test.crt", "./testdata/test.key")
	require.NoError(t, err)

	tests := []struct {
		name    string
		ca      CA
		wantErr string
	}{
		{
			name: "failed to generate serial",
			ca: CA{
				env: env{
					serialRNG: drained(),
				},
			},
			wantErr: "could not generate serial number for certificate: EOF",
		},
		{
			name: "failed to generate keypair",
			ca: CA{
				env: env{
					serialRNG: plenty(),
					keygenRNG: drained(),
				},
			},
			wantErr: "could not generate private key: EOF",
		},
		{
			name: "invalid CA certificate",
			ca: CA{
				env: env{
					serialRNG: plenty(),
					keygenRNG: plenty(),
					clock:     fixedClock,
				},
			},
			wantErr: "could not parse CA certificate: x509: malformed certificate",
		},
		{
			name: "signing error",
			ca: CA{
				env: env{
					serialRNG: plenty(),
					keygenRNG: plenty(),
					clock:     fixedClock,
				},
				caCertBytes: realCA.caCertBytes,
				signer: &errSigner{
					pubkey: realCA.signer.Public(),
					err:    fmt.Errorf("some signer error"),
				},
			},
			wantErr: "could not sign certificate: some signer error",
		},
		{
			name: "parse certificate error",
			ca: CA{
				env: env{
					serialRNG: plenty(),
					keygenRNG: plenty(),
					clock:     fixedClock,
					parseCert: func(_ []byte) (*x509.Certificate, error) {
						return nil, fmt.Errorf("some parse certificate error")
					},
				},
				caCertBytes: realCA.caCertBytes,
				signer:      realCA.signer,
			},
			wantErr: "could not parse certificate: some parse certificate error",
		},
		{
			name: "success",
			ca: CA{
				env: env{
					serialRNG: plenty(),
					keygenRNG: plenty(),
					clock:     fixedClock,
					parseCert: x509.ParseCertificate,
				},
				caCertBytes: realCA.caCertBytes,
				signer:      realCA.signer,
			},
		},
	}
	for _, tc := range tests {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			serverCert, err := tc.ca.IssueServerCert([]string{"example.com"}, []net.IP{net.IPv4(1, 2, 3, 4)}, 10*time.Minute)
			if tc.wantErr == "" {
				require.NoError(t, err)
				require.NotNil(t, serverCert)
			} else {
				require.EqualError(t, err, tc.wantErr)
				require.Nil(t, serverCert)
			}

			clientCert, err := tc.ca.IssueClientCert("test-user", []string{"group1", "group2"}, 10*time.Minute)
			if tc.wantErr == "" {
				require.NoError(t, err)
				require.NotNil(t, clientCert)
			} else {
				require.EqualError(t, err, tc.wantErr)
				require.Nil(t, clientCert)
			}
		})
	}
}

// TestToPEM covers toPEM's three paths: propagating an error passed in by the
// caller, failing to marshal a missing private key, and encoding a valid key
// pair.
func TestToPEM(t *testing.T) {
	keyPair, err := tls.LoadX509KeyPair("./testdata/test.crt", "./testdata/test.key")
	require.NoError(t, err)

	t.Run("error from input", func(t *testing.T) {
		gotCert, gotKey, err := toPEM(nil, fmt.Errorf("some error"))
		require.EqualError(t, err, "some error")
		require.Nil(t, gotCert)
		require.Nil(t, gotKey)
	})

	t.Run("invalid private key", func(t *testing.T) {
		// Mutate a copy so the shared keyPair keeps its key for the "success" subtest.
		noKey := keyPair
		noKey.PrivateKey = nil
		gotCert, gotKey, err := toPEM(&noKey, nil)
		require.EqualError(t, err, "failed to marshal private key into PKCS8: x509: unknown key type while marshaling PKCS#8: <nil>")
		require.Nil(t, gotCert)
		require.Nil(t, gotKey)
	})

	t.Run("success", func(t *testing.T) {
		gotCert, gotKey, err := toPEM(&keyPair, nil)
		require.NoError(t, err)
		require.NotEmpty(t, gotCert)
		require.NotEmpty(t, gotKey)
	})
}

// TestIssueMethods issues client and server certificates from a single CA.
// It uses both the struct-returning issuers (followed by ToPEM) and the
// PEM-returning issuers, and validates every result against the CA bundle.
// Nil and empty group, DNS-name, and IP inputs are expected to validate
// against nil expectations.
func TestIssueMethods(t *testing.T) {
	// One CA can be used to issue both kinds of certs.
	ca, err := New("Test CA", time.Hour)
	require.NoError(t, err)

	const ttl = 121 * time.Hour

	t.Run("client certs", func(t *testing.T) {
		const user = "test-username"
		groups := []string{"group1", "group2"}

		issued, err := ca.IssueClientCert(user, groups, ttl)
		require.NoError(t, err)
		certPEM, keyPEM, err := ToPEM(issued)
		require.NoError(t, err)
		validateClientCert(t, ca.Bundle(), certPEM, keyPEM, user, groups, ttl)

		for _, c := range []struct {
			user       string
			groups     []string
			wantGroups []string
		}{
			{user: user, groups: groups, wantGroups: groups},
			{user: user, groups: nil, wantGroups: nil},
			{user: user, groups: []string{}, wantGroups: nil},
			{user: "", groups: []string{}, wantGroups: nil},
		} {
			gotCert, gotKey, err := ca.IssueClientCertPEM(c.user, c.groups, ttl)
			require.NoError(t, err)
			validateClientCert(t, ca.Bundle(), gotCert, gotKey, c.user, c.wantGroups, ttl)
		}
	})

	t.Run("server certs", func(t *testing.T) {
		dnsNames := []string{"example.com", "pinniped.dev"}
		ips := []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("1.2.3.4")}

		issued, err := ca.IssueServerCert(dnsNames, ips, ttl)
		require.NoError(t, err)
		certPEM, keyPEM, err := ToPEM(issued)
		require.NoError(t, err)
		validateServerCert(t, ca.Bundle(), certPEM, keyPEM, dnsNames, ips, ttl)

		for _, c := range []struct {
			dnsNames     []string
			ips          []net.IP
			wantDNSNames []string
			wantIPs      []net.IP
		}{
			{dnsNames: dnsNames, ips: ips, wantDNSNames: dnsNames, wantIPs: ips},
			{dnsNames: nil, ips: ips, wantDNSNames: nil, wantIPs: ips},
			{dnsNames: dnsNames, ips: nil, wantDNSNames: dnsNames, wantIPs: nil},
			{dnsNames: []string{}, ips: ips, wantDNSNames: nil, wantIPs: ips},
			{dnsNames: dnsNames, ips: []net.IP{}, wantDNSNames: dnsNames, wantIPs: nil},
		} {
			gotCert, gotKey, err := ca.IssueServerCertPEM(c.dnsNames, c.ips, ttl)
			require.NoError(t, err)
			validateServerCert(t, ca.Bundle(), gotCert, gotKey, c.wantDNSNames, c.wantIPs, ttl)
		}
	})
}

// validateClientCert asserts that certPEM is a client certificate that
// testutil.ValidateClientCertificate accepts against caBundle, and that:
//   - keyPEM is its matching private key;
//   - its common name is expectedUser and its organizations are expectedGroups;
//   - it has no DNS names or IP addresses.
//
// The validity window is checked against [now, now+expectedTTL] with a
// tolerance of certBackdate plus a small fudge factor. The exact comparison
// lives in RequireLifetime.
func validateClientCert(t *testing.T, caBundle []byte, certPEM []byte, keyPEM []byte, expectedUser string, expectedGroups []string, expectedTTL time.Duration) {
	t.Helper() // report failures at the caller's line, like loadFromFiles does

	// Extra slack on top of certBackdate for time elapsed between issuing the
	// certificate and running these checks.
	const fudgeFactor = 10 * time.Second

	// Sample the clock once so both ends of the expected window share one reference instant.
	now := time.Now()

	v := testutil.ValidateClientCertificate(t, string(caBundle), string(certPEM))
	v.RequireLifetime(now, now.Add(expectedTTL), certBackdate+fudgeFactor)
	v.RequireMatchesPrivateKey(string(keyPEM))
	v.RequireCommonName(expectedUser)
	v.RequireOrganizations(expectedGroups)
	v.RequireEmptyDNSNames()
	v.RequireEmptyIPs()
}

// validateServerCert asserts that certPEM is a server certificate that
// testutil.ValidateServerCertificate accepts against caBundle, and that:
//   - keyPEM is its matching private key;
//   - its common name is empty;
//   - its DNS names and IP addresses are expectedDNSNames and expectedIPs.
//
// The validity window is checked against [now, now+expectedTTL] with a
// tolerance of certBackdate plus a small fudge factor. The exact comparison
// lives in RequireLifetime.
func validateServerCert(t *testing.T, caBundle []byte, certPEM []byte, keyPEM []byte, expectedDNSNames []string, expectedIPs []net.IP, expectedTTL time.Duration) {
	t.Helper() // report failures at the caller's line, like loadFromFiles does

	// Extra slack on top of certBackdate for time elapsed between issuing the
	// certificate and running these checks.
	const fudgeFactor = 10 * time.Second

	// Sample the clock once so both ends of the expected window share one reference instant.
	now := time.Now()

	v := testutil.ValidateServerCertificate(t, string(caBundle), string(certPEM))
	v.RequireLifetime(now, now.Add(expectedTTL), certBackdate+fudgeFactor)
	v.RequireMatchesPrivateKey(string(keyPEM))
	v.RequireCommonName("")
	v.RequireDNSNames(expectedDNSNames)
	v.RequireIPs(expectedIPs)
}