// Copyright 2020-2021 the Pinniped contributors. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

// Package oidcupstreamwatcher implements a controller which watches OIDCIdentityProviders.
package oidcupstreamwatcher

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"fmt"
	"net/http"
	"net/url"
	"sort"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/go-logr/logr"
	"golang.org/x/oauth2"
	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/equality"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/labels"
	"k8s.io/apimachinery/pkg/util/cache"
	corev1informers "k8s.io/client-go/informers/core/v1"

	"go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
	pinnipedclientset "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned"
	idpinformers "go.pinniped.dev/generated/latest/client/supervisor/informers/externalversions/idp/v1alpha1"
	"go.pinniped.dev/internal/constable"
	pinnipedcontroller "go.pinniped.dev/internal/controller"
	"go.pinniped.dev/internal/controller/conditionsutil"
	"go.pinniped.dev/internal/controller/supervisorconfig/upstreamwatchers"
	"go.pinniped.dev/internal/controllerlib"
	"go.pinniped.dev/internal/oidc/provider"
	"go.pinniped.dev/internal/upstreamoidc"
)

// Controller-wide settings.
const (
	// oidcControllerName names this controller in logs and in its controllerlib.Config.
	oidcControllerName = "oidc-upstream-observer"

	// oidcValidatorCacheTTL is how long a discovered *oidc.Provider (and the HTTP client used to reach it)
	// is reused before OIDC discovery is performed again. It does not affect the cache of JWKS.
	oidcValidatorCacheTTL = 15 * time.Minute
)

// The client credentials Secret referenced by spec.client.secretName: its required type and data keys.
const (
	oidcClientSecretType corev1.SecretType = "secrets.pinniped.dev/oidc-client"

	clientIDDataKey     = "clientID"
	clientSecretDataKey = "clientSecret"
)

// Condition types and controller-specific reasons written to OIDCIdentityProvider status.
const (
	typeClientCredentialsValid = "ClientCredentialsValid"
	typeOIDCDiscoverySucceeded = "OIDCDiscoverySucceeded"

	reasonUnreachable     = "Unreachable"
	reasonInvalidResponse = "InvalidResponse"
)

// errOIDCFailureStatus is the error logged alongside each failing condition found while reconciling.
const errOIDCFailureStatus = constable.Error("OIDCIdentityProvider has a failing condition")

// UpstreamOIDCIdentityProviderICache is a thread safe cache that holds a list of validated upstream OIDC IDP configurations.
//
// This controller's Sync calls SetOIDCIdentityProviders once per run (unless listing the upstreams fails),
// passing every upstream that passed validation in that run as a non-nil, possibly empty slice.
// Whether a call replaces or merges with previously stored values is up to the implementation,
// which is not visible here.
type UpstreamOIDCIdentityProviderICache interface {
	// SetOIDCIdentityProviders is handed the upstream configurations that passed validation.
	SetOIDCIdentityProviders([]provider.UpstreamOIDCIdentityProviderI)
}

// lruValidatorCache caches the *oidc.Provider associated with a particular issuer/TLS configuration.
//
// Despite the name, entries are evicted by time rather than by recency. Every entry is stored with
// oidcValidatorCacheTTL in the underlying cache.Expiring and disappears once that TTL elapses.
type lruValidatorCache struct{ cache *cache.Expiring }

// lruValidatorCacheEntry is the value stored per cache key. It holds a discovered provider together with
// the HTTP client that was used to discover it, so callers can keep talking to that issuer with the same client.
type lruValidatorCacheEntry struct {
	provider *oidc.Provider
	client   *http.Client
}

// getProvider returns the cached provider and HTTP client for the spec's issuer and CA bundle.
// Both results are nil when there is no live entry (never stored, or already expired).
func (c *lruValidatorCache) getProvider(spec *v1alpha1.OIDCIdentityProviderSpec) (*oidc.Provider, *http.Client) {
	cached, found := c.cache.Get(c.cacheKey(spec))
	if !found {
		return nil, nil
	}
	entry := cached.(*lruValidatorCacheEntry) // putProvider is the only writer, so the assertion cannot fail
	return entry.provider, entry.client
}

// putProvider stores the provider and HTTP client under the spec's issuer and CA bundle, (re)starting the entry's TTL.
func (c *lruValidatorCache) putProvider(spec *v1alpha1.OIDCIdentityProviderSpec, provider *oidc.Provider, client *http.Client) {
	entry := &lruValidatorCacheEntry{provider: provider, client: client}
	c.cache.Set(c.cacheKey(spec), entry, oidcValidatorCacheTTL)
}

// cacheKey builds a comparable key from the two spec fields that influence discovery. These are the issuer URL
// and the base64-encoded CA bundle; the bundle is the empty string when spec.TLS is unset.
func (c *lruValidatorCache) cacheKey(spec *v1alpha1.OIDCIdentityProviderSpec) interface{} {
	caBundle := ""
	if tlsSpec := spec.TLS; tlsSpec != nil {
		caBundle = tlsSpec.CertificateAuthorityData
	}
	return struct{ issuer, caBundle string }{issuer: spec.Issuer, caBundle: caBundle}
}

// oidcWatcherController is the controllerlib.Syncer constructed by New. On each Sync it validates every
// OIDCIdentityProvider, writes their status, and hands the valid ones to cache.
//
// Fields:
//   - cache: receives the validated upstream configurations at the end of each Sync.
//   - log: base logger (New names it oidcControllerName); per-upstream log lines add namespace/name values.
//   - client: in this file, used only to call UpdateStatus on OIDCIdentityProviders.
//   - oidcIdentityProviderInformer: lister for the upstreams to validate.
//   - secretInformer: lister for the client credential Secrets named by spec.client.secretName.
//   - validatorCache: caches OIDC discovery results per issuer/CA bundle. New installs an lruValidatorCache;
//     it is declared as an interface, presumably so tests can substitute a fake (confirm).
type oidcWatcherController struct {
	cache                        UpstreamOIDCIdentityProviderICache
	log                          logr.Logger
	client                       pinnipedclientset.Interface
	oidcIdentityProviderInformer idpinformers.OIDCIdentityProviderInformer
	secretInformer               corev1informers.SecretInformer
	validatorCache               interface {
		getProvider(*v1alpha1.OIDCIdentityProviderSpec) (*oidc.Provider, *http.Client)
		putProvider(*v1alpha1.OIDCIdentityProviderSpec, *oidc.Provider, *http.Client)
	}
}

// New instantiates a new controllerlib.Controller which will populate the provided UpstreamOIDCIdentityProviderICache.
//
// The controller is triggered by any OIDCIdentityProvider event and by events on Secrets of the OIDC client
// secret type. Both filters route into pinnipedcontroller.SingletonQueue(), and Sync itself re-lists and
// revalidates every upstream. Discovery results are cached in a fresh TTL-based validator cache.
func New(
	idpCache UpstreamOIDCIdentityProviderICache,
	client pinnipedclientset.Interface,
	oidcIdentityProviderInformer idpinformers.OIDCIdentityProviderInformer,
	secretInformer corev1informers.SecretInformer,
	log logr.Logger,
	withInformer pinnipedcontroller.WithInformerOptionFunc,
) controllerlib.Controller {
	syncer := &oidcWatcherController{
		cache:                        idpCache,
		log:                          log.WithName(oidcControllerName),
		client:                       client,
		oidcIdentityProviderInformer: oidcIdentityProviderInformer,
		secretInformer:               secretInformer,
		validatorCache:               &lruValidatorCache{cache: cache.NewExpiring()},
	}

	// Watch every OIDCIdentityProvider.
	upstreamWatch := withInformer(
		oidcIdentityProviderInformer,
		pinnipedcontroller.MatchAnythingFilter(pinnipedcontroller.SingletonQueue()),
		controllerlib.InformerOption{},
	)
	// Watch only Secrets that could hold OIDC client credentials.
	secretWatch := withInformer(
		secretInformer,
		pinnipedcontroller.MatchAnySecretOfTypeFilter(oidcClientSecretType, pinnipedcontroller.SingletonQueue()),
		controllerlib.InformerOption{},
	)

	return controllerlib.New(
		controllerlib.Config{Name: oidcControllerName, Syncer: syncer},
		upstreamWatch,
		secretWatch,
	)
}

// Sync implements controllerlib.Syncer.
//
// It revalidates every OIDCIdentityProvider known to the informer (updating each one's status as a side effect)
// and passes those that validated to the cache. If at least one upstream failed validation,
// controllerlib.ErrSyntheticRequeue is returned so the sync is queued again.
func (c *oidcWatcherController) Sync(ctx controllerlib.Context) error {
	upstreams, err := c.oidcIdentityProviderInformer.Lister().List(labels.Everything())
	if err != nil {
		return fmt.Errorf("failed to list OIDCIdentityProviders: %w", err)
	}

	healthy := make([]provider.UpstreamOIDCIdentityProviderI, 0, len(upstreams))
	failures := 0
	for _, upstream := range upstreams {
		validated := c.validateUpstream(ctx, upstream)
		if validated == nil {
			failures++
			continue
		}
		healthy = append(healthy, provider.UpstreamOIDCIdentityProviderI(validated))
	}
	c.cache.SetOIDCIdentityProviders(healthy)

	if failures > 0 {
		return controllerlib.ErrSyntheticRequeue
	}
	return nil
}

// validateUpstream validates the provided v1alpha1.OIDCIdentityProvider and returns the validated configuration as a
// provider.UpstreamOIDCIdentityProvider. As a side effect, it also updates the status of the v1alpha1.OIDCIdentityProvider.
//
// Both the client-secret check and the issuer check always run, so the status reports every problem at once.
// Each False condition is logged at error level. The function returns nil if any condition is False.
func (c *oidcWatcherController) validateUpstream(ctx controllerlib.Context, upstream *v1alpha1.OIDCIdentityProvider) *upstreamoidc.ProviderConfig {
	spec := &upstream.Spec
	config := &upstreamoidc.ProviderConfig{
		Name:          upstream.Name,
		Config:        &oauth2.Config{Scopes: computeScopes(spec.AuthorizationConfig.AdditionalScopes)},
		UsernameClaim: spec.Claims.Username,
		GroupsClaim:   spec.Claims.Groups,
	}

	// Each validator fills in its part of config when it succeeds.
	secretCondition := c.validateSecret(upstream, config)
	issuerCondition := c.validateIssuer(ctx.Context, upstream, config)
	conditions := []*v1alpha1.Condition{secretCondition, issuerCondition}
	c.updateStatus(ctx.Context, upstream, conditions)

	logger := c.log.WithValues("namespace", upstream.Namespace, "name", upstream.Name)
	failed := false
	for _, cond := range conditions {
		if cond.Status != v1alpha1.ConditionFalse {
			continue
		}
		failed = true
		logger.WithValues(
			"type", cond.Type,
			"reason", cond.Reason,
			"message", cond.Message,
		).Error(errOIDCFailureStatus, "found failing condition")
	}

	if failed {
		return nil
	}
	return config
}

// validateSecret validates the .spec.client.secretName field and returns the appropriate ClientCredentialsValid condition.
//
// The Secret is read from the informer cache in the upstream's namespace. It must have the OIDC client secret
// type and non-empty clientID and clientSecret data. Only on success are result.Config.ClientID and
// result.Config.ClientSecret filled in.
func (c *oidcWatcherController) validateSecret(upstream *v1alpha1.OIDCIdentityProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition {
	secretName := upstream.Spec.Client.SecretName

	// invalid builds a False ClientCredentialsValid condition.
	invalid := func(reason, message string) *v1alpha1.Condition {
		return &v1alpha1.Condition{
			Type:    typeClientCredentialsValid,
			Status:  v1alpha1.ConditionFalse,
			Reason:  reason,
			Message: message,
		}
	}

	secret, err := c.secretInformer.Lister().Secrets(upstream.Namespace).Get(secretName)
	switch {
	case err != nil:
		return invalid(upstreamwatchers.ReasonNotFound, err.Error())
	case secret.Type != oidcClientSecretType:
		return invalid(
			upstreamwatchers.ReasonWrongType,
			fmt.Sprintf("referenced Secret %q has wrong type %q (should be %q)", secretName, secret.Type, oidcClientSecretType),
		)
	}

	clientID, clientSecret := secret.Data[clientIDDataKey], secret.Data[clientSecretDataKey]
	if len(clientID) == 0 || len(clientSecret) == 0 {
		return invalid(
			upstreamwatchers.ReasonMissingKeys,
			fmt.Sprintf("referenced Secret %q is missing required keys %q", secretName, []string{clientIDDataKey, clientSecretDataKey}),
		)
	}

	result.Config.ClientID = string(clientID)
	result.Config.ClientSecret = string(clientSecret)
	return &v1alpha1.Condition{
		Type:    typeClientCredentialsValid,
		Status:  v1alpha1.ConditionTrue,
		Reason:  upstreamwatchers.ReasonSuccess,
		Message: "loaded client credentials",
	}
}

// validateIssuer validates the .spec.issuer field, performs OIDC discovery, and returns the appropriate OIDCDiscoverySucceeded condition.
//
// The discovered provider and the HTTP client used to reach it are cached via c.validatorCache, keyed by
// issuer and CA bundle, so discovery only goes to the network on a cache miss. The discovered authorization
// endpoint must parse as a URL with the "https" scheme.
//
// On success, result.Config.Endpoint, result.Provider and result.Client are populated. On any failure, result
// is left untouched and a False condition describing the problem is returned.
func (c *oidcWatcherController) validateIssuer(ctx context.Context, upstream *v1alpha1.OIDCIdentityProvider, result *upstreamoidc.ProviderConfig) *v1alpha1.Condition {
	// Get the provider and HTTP Client from cache if possible.
	discoveredProvider, httpClient := c.validatorCache.getProvider(&upstream.Spec)

	// If the provider does not exist in the cache, do a fresh discovery lookup and save to the cache.
	if discoveredProvider == nil {
		tlsConfig, err := getTLSConfig(upstream)
		if err != nil {
			return &v1alpha1.Condition{
				Type:    typeOIDCDiscoverySucceeded,
				Status:  v1alpha1.ConditionFalse,
				Reason:  upstreamwatchers.ReasonInvalidTLSConfig,
				Message: err.Error(),
			}
		}

		httpClient = &http.Client{
			Timeout: time.Minute,
			Transport: &http.Transport{
				Proxy:           http.ProxyFromEnvironment,
				TLSClientConfig: tlsConfig,
			},
		}

		discoveredProvider, err = oidc.NewProvider(oidc.ClientContext(ctx, httpClient), upstream.Spec.Issuer)
		if err != nil {
			// logr emits Error() regardless of the V() level it is called on, so trace-level output must
			// go through Info. The failure already reaches the logs at error level, because validateUpstream
			// logs every False condition, including the truncated error below. The full, untruncated error
			// is only wanted when trace logging is enabled.
			const klogLevelTrace = 6
			c.log.V(klogLevelTrace).Info("failed to perform OIDC discovery",
				"namespace", upstream.Namespace,
				"name", upstream.Name,
				"issuer", upstream.Spec.Issuer,
				"err", err,
			)
			return &v1alpha1.Condition{
				Type:    typeOIDCDiscoverySucceeded,
				Status:  v1alpha1.ConditionFalse,
				Reason:  reasonUnreachable,
				Message: fmt.Sprintf("failed to perform OIDC discovery against %q:\n%s", upstream.Spec.Issuer, truncateNonOIDCErr(err)),
			}
		}

		// Update the cache with the newly discovered value.
		c.validatorCache.putProvider(&upstream.Spec, discoveredProvider, httpClient)
	}

	// Parse out and validate the discovered authorize endpoint.
	endpoint := discoveredProvider.Endpoint()
	authURL, err := url.Parse(endpoint.AuthURL)
	if err != nil {
		return &v1alpha1.Condition{
			Type:    typeOIDCDiscoverySucceeded,
			Status:  v1alpha1.ConditionFalse,
			Reason:  reasonInvalidResponse,
			Message: fmt.Sprintf("failed to parse authorization endpoint URL: %v", err),
		}
	}
	if authURL.Scheme != "https" {
		return &v1alpha1.Condition{
			Type:    typeOIDCDiscoverySucceeded,
			Status:  v1alpha1.ConditionFalse,
			Reason:  reasonInvalidResponse,
			Message: fmt.Sprintf(`authorization endpoint URL scheme must be "https", not %q`, authURL.Scheme),
		}
	}

	// If everything is valid, update the result and set the condition to true.
	result.Config.Endpoint = endpoint
	result.Provider = discoveredProvider
	result.Client = httpClient
	return &v1alpha1.Condition{
		Type:    typeOIDCDiscoverySucceeded,
		Status:  v1alpha1.ConditionTrue,
		Reason:  upstreamwatchers.ReasonSuccess,
		Message: "discovered issuer configuration",
	}
}

// updateStatus merges conditions into a deep copy of upstream's status. It sets the phase to Error when
// conditionsutil.Merge reports an error condition and to Ready otherwise. The API server is called only when
// the result differs semantically from upstream. A failed status update is logged rather than returned.
func (c *oidcWatcherController) updateStatus(ctx context.Context, upstream *v1alpha1.OIDCIdentityProvider, conditions []*v1alpha1.Condition) {
	logger := c.log.WithValues("namespace", upstream.Namespace, "name", upstream.Name)
	desired := upstream.DeepCopy()

	if conditionsutil.Merge(conditions, upstream.Generation, &desired.Status.Conditions, logger) {
		desired.Status.Phase = v1alpha1.PhaseError
	} else {
		desired.Status.Phase = v1alpha1.PhaseReady
	}

	// Skip the write entirely when nothing changed.
	if equality.Semantic.DeepEqual(upstream, desired) {
		return
	}

	upstreams := c.client.IDPV1alpha1().OIDCIdentityProviders(upstream.Namespace)
	if _, err := upstreams.UpdateStatus(ctx, desired, metav1.UpdateOptions{}); err != nil {
		logger.Error(err, "failed to update status")
	}
}

// getTLSConfig builds the client TLS configuration used for OIDC discovery against upstream's issuer.
//
// The minimum version is always TLS 1.2. When spec.tls.certificateAuthorityData is non-empty, it must be
// base64-encoded PEM containing at least one certificate, and those certificates become the only trusted
// roots. Otherwise RootCAs stays nil, which crypto/tls treats as "use the host's root CA set".
func getTLSConfig(upstream *v1alpha1.OIDCIdentityProvider) (*tls.Config, error) {
	cfg := &tls.Config{MinVersion: tls.VersionTLS12}

	var encodedCAs string
	if upstream.Spec.TLS != nil {
		encodedCAs = upstream.Spec.TLS.CertificateAuthorityData
	}
	if encodedCAs == "" {
		return cfg, nil
	}

	pemBundle, err := base64.StdEncoding.DecodeString(encodedCAs)
	if err != nil {
		return nil, fmt.Errorf("spec.certificateAuthorityData is invalid: %w", err)
	}

	pool := x509.NewCertPool()
	if ok := pool.AppendCertsFromPEM(pemBundle); !ok {
		return nil, fmt.Errorf("spec.certificateAuthorityData is invalid: %w", upstreamwatchers.ErrNoCertificates)
	}
	cfg.RootCAs = pool
	return cfg, nil
}

// computeScopes returns the scopes to request upstream: "openid" plus additionalScopes, de-duplicated and
// sorted in ascending lexical order. The caller's slice is not modified.
func computeScopes(additionalScopes []string) []string {
	scopes := make([]string, 0, len(additionalScopes)+1)
	scopes = append(scopes, "openid")
	scopes = append(scopes, additionalScopes...)
	sort.Strings(scopes)

	// After sorting, duplicates are adjacent; keep one element per run, compacting in place.
	unique := scopes[:0]
	for _, s := range scopes {
		if len(unique) == 0 || unique[len(unique)-1] != s {
			unique = append(unique, s)
		}
	}
	return unique
}

// truncateNonOIDCErr renders err for inclusion in a status condition message.
//
// Messages starting with "oidc:" (presumably the go-oidc library's own errors) and messages of at most 100
// bytes are returned unchanged. Longer messages are cut to at most 100 bytes and suffixed with the number of
// bytes dropped. The suffix says "chars", but the count is in bytes.
//
// The cut is moved back to the start of the rune it would otherwise land inside. This keeps a multi-byte
// UTF-8 character from being split, which would put invalid UTF-8 into the condition message.
func truncateNonOIDCErr(err error) string {
	const maxLen = 100
	msg := err.Error()

	if len(msg) <= maxLen || strings.HasPrefix(msg, "oidc:") {
		return msg
	}

	// A valid UTF-8 rune is at most utf8.UTFMax bytes long, so at most UTFMax-1 steps back reach its first
	// byte. Bounding the walk keeps pathological (already invalid) input from eating the whole message.
	cut := maxLen
	for steps := 1; steps < utf8.UTFMax && !utf8.RuneStart(msg[cut]); steps++ {
		cut--
	}
	return msg[:cut] + fmt.Sprintf(" [truncated %d chars]", len(msg)-cut)
}