Make it possible to compare transformation pipelines in unit tests

This commit is contained in:
Ryan Richard 2023-07-17 16:41:28 -07:00
parent c771328bb1
commit b89e6d9d93
5 changed files with 414 additions and 31 deletions

View File

@ -160,6 +160,7 @@ func (t *UsernameTransformation) compile(transformer *CELTransformer, consts *Tr
baseCompiledTransformation: &baseCompiledTransformation{ baseCompiledTransformation: &baseCompiledTransformation{
program: program, program: program,
consts: consts, consts: consts,
sourceExpr: t,
maxExpressionRuntime: transformer.maxExpressionRuntime, maxExpressionRuntime: transformer.maxExpressionRuntime,
}, },
}, nil }, nil
@ -174,6 +175,7 @@ func (t *GroupsTransformation) compile(transformer *CELTransformer, consts *Tran
baseCompiledTransformation: &baseCompiledTransformation{ baseCompiledTransformation: &baseCompiledTransformation{
program: program, program: program,
consts: consts, consts: consts,
sourceExpr: t,
maxExpressionRuntime: transformer.maxExpressionRuntime, maxExpressionRuntime: transformer.maxExpressionRuntime,
}, },
}, nil }, nil
@ -188,6 +190,7 @@ func (t *AllowAuthenticationPolicy) compile(transformer *CELTransformer, consts
baseCompiledTransformation: &baseCompiledTransformation{ baseCompiledTransformation: &baseCompiledTransformation{
program: program, program: program,
consts: consts, consts: consts,
sourceExpr: t,
maxExpressionRuntime: transformer.maxExpressionRuntime, maxExpressionRuntime: transformer.maxExpressionRuntime,
}, },
rejectedAuthenticationMessage: t.RejectedAuthenticationMessage, rejectedAuthenticationMessage: t.RejectedAuthenticationMessage,
@ -198,6 +201,7 @@ func (t *AllowAuthenticationPolicy) compile(transformer *CELTransformer, consts
type baseCompiledTransformation struct { type baseCompiledTransformation struct {
program cel.Program program cel.Program
consts *TransformationConstants consts *TransformationConstants
sourceExpr CELTransformation
maxExpressionRuntime time.Duration maxExpressionRuntime time.Duration
} }
@ -302,6 +306,23 @@ func (c *compiledAllowAuthenticationPolicy) Evaluate(ctx context.Context, userna
return result, nil return result, nil
} }
type CELTransformationSource struct {
Expr CELTransformation
Consts *TransformationConstants
}
func (c *compiledUsernameTransformation) Source() interface{} {
return &CELTransformationSource{Expr: c.sourceExpr, Consts: c.consts}
}
func (c *compiledGroupsTransformation) Source() interface{} {
return &CELTransformationSource{Expr: c.sourceExpr, Consts: c.consts}
}
func (c *compiledAllowAuthenticationPolicy) Source() interface{} {
return &CELTransformationSource{Expr: c.sourceExpr, Consts: c.consts}
}
func newEnv() (*cel.Env, error) { func newEnv() (*cel.Env, error) {
// Note that Kubernetes uses CEL in several places, which are helpful to see as an example of // Note that Kubernetes uses CEL in several places, which are helpful to see as an example of
// how to configure the CEL compiler for production usage. Examples: // how to configure the CEL compiler for production usage. Examples:

View File

@ -765,6 +765,7 @@ func TestTransformer(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
pipeline := idtransform.NewTransformationPipeline() pipeline := idtransform.NewTransformationPipeline()
expectedPipelineSource := []interface{}{}
for _, transform := range tt.transforms { for _, transform := range tt.transforms {
compiledTransform, err := transformer.CompileTransformation(transform, tt.consts) compiledTransform, err := transformer.CompileTransformation(transform, tt.consts)
@ -774,6 +775,15 @@ func TestTransformer(t *testing.T) {
} }
require.NoError(t, err, "got an unexpected compile error") require.NoError(t, err, "got an unexpected compile error")
pipeline.AppendTransformation(compiledTransform) pipeline.AppendTransformation(compiledTransform)
expectedTransformSource := &CELTransformationSource{
Expr: transform,
Consts: tt.consts,
}
if expectedTransformSource.Consts == nil {
expectedTransformSource.Consts = &TransformationConstants{}
}
expectedPipelineSource = append(expectedPipelineSource, expectedTransformSource)
} }
ctx := context.Background() ctx := context.Background()
@ -792,6 +802,8 @@ func TestTransformer(t *testing.T) {
require.Equal(t, tt.wantGroups, result.Groups) require.Equal(t, tt.wantGroups, result.Groups)
require.Equal(t, !tt.wantAuthRejected, result.AuthenticationAllowed, "AuthenticationAllowed had unexpected value") require.Equal(t, !tt.wantAuthRejected, result.AuthenticationAllowed, "AuthenticationAllowed had unexpected value")
require.Equal(t, tt.wantAuthRejectedMessage, result.RejectedAuthenticationMessage) require.Equal(t, tt.wantAuthRejectedMessage, result.RejectedAuthenticationMessage)
require.Equal(t, expectedPipelineSource, pipeline.Source())
}) })
} }
} }

View File

@ -17,6 +17,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/types"
coretesting "k8s.io/client-go/testing" coretesting "k8s.io/client-go/testing"
clocktesting "k8s.io/utils/clock/testing" clocktesting "k8s.io/utils/clock/testing"
"k8s.io/utils/pointer" "k8s.io/utils/pointer"
@ -25,6 +26,7 @@ import (
idpv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1" idpv1alpha1 "go.pinniped.dev/generated/latest/apis/supervisor/idp/v1alpha1"
pinnipedfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake" pinnipedfake "go.pinniped.dev/generated/latest/client/supervisor/clientset/versioned/fake"
pinnipedinformers "go.pinniped.dev/generated/latest/client/supervisor/informers/externalversions" pinnipedinformers "go.pinniped.dev/generated/latest/client/supervisor/informers/externalversions"
"go.pinniped.dev/internal/celtransformer"
"go.pinniped.dev/internal/controllerlib" "go.pinniped.dev/internal/controllerlib"
"go.pinniped.dev/internal/federationdomain/federationdomainproviders" "go.pinniped.dev/internal/federationdomain/federationdomainproviders"
"go.pinniped.dev/internal/here" "go.pinniped.dev/internal/here"
@ -1282,7 +1284,7 @@ func TestTestFederationDomainWatcherControllerSync(t *testing.T) {
Transforms: configv1alpha1.FederationDomainTransforms{ Transforms: configv1alpha1.FederationDomainTransforms{
Constants: []configv1alpha1.FederationDomainTransformsConstant{ Constants: []configv1alpha1.FederationDomainTransformsConstant{
{Name: "duplicate1", Type: "string", StringValue: "abc"}, {Name: "duplicate1", Type: "string", StringValue: "abc"},
{Name: "duplicate1", Type: "string", StringValue: "def"}, {Name: "duplicate1", Type: "stringList", StringListValue: []string{"def"}},
{Name: "duplicate1", Type: "string", StringValue: "efg"}, {Name: "duplicate1", Type: "string", StringValue: "efg"},
{Name: "duplicate2", Type: "string", StringValue: "123"}, {Name: "duplicate2", Type: "string", StringValue: "123"},
{Name: "duplicate2", Type: "string", StringValue: "456"}, {Name: "duplicate2", Type: "string", StringValue: "456"},
@ -1612,6 +1614,133 @@ func TestTestFederationDomainWatcherControllerSync(t *testing.T) {
), ),
}, },
}, },
{
name: "the federation domain has valid IDPs and transformations and examples",
inputObjects: []runtime.Object{
oidcIdentityProvider,
ldapIdentityProvider,
&configv1alpha1.FederationDomain{
ObjectMeta: metav1.ObjectMeta{Name: "config1", Namespace: namespace, Generation: 123},
Spec: configv1alpha1.FederationDomainSpec{
Issuer: "https://issuer1.com",
IdentityProviders: []configv1alpha1.FederationDomainIdentityProvider{
{
DisplayName: "name1",
ObjectRef: corev1.TypedLocalObjectReference{
APIGroup: pointer.String(apiGroupSupervisor),
Kind: "OIDCIdentityProvider",
Name: oidcIdentityProvider.Name,
},
Transforms: configv1alpha1.FederationDomainTransforms{
Expressions: []configv1alpha1.FederationDomainTransformsExpression{
{Type: "policy/v1", Expression: `username == "ryan" || username == "rejectMeWithDefaultMessage"`, Message: "only ryan allowed"},
{Type: "policy/v1", Expression: `username != "rejectMeWithDefaultMessage"`}, // no message specified
{Type: "username/v1", Expression: `"pre:" + username`},
{Type: "groups/v1", Expression: `groups.map(g, "pre:" + g)`},
},
Constants: []configv1alpha1.FederationDomainTransformsConstant{
{Name: "str", Type: "string", StringValue: "abc"},
{Name: "strL", Type: "stringList", StringListValue: []string{"def"}},
},
Examples: []configv1alpha1.FederationDomainTransformsExample{
{
Username: "ryan",
Groups: []string{"a", "b"},
Expects: configv1alpha1.FederationDomainTransformsExampleExpects{
Username: "pre:ryan",
Groups: []string{"pre:b", "pre:a"},
Rejected: false,
},
},
{
Username: "other",
Expects: configv1alpha1.FederationDomainTransformsExampleExpects{
Rejected: true,
Message: "only ryan allowed",
},
},
{
Username: "rejectMeWithDefaultMessage",
Expects: configv1alpha1.FederationDomainTransformsExampleExpects{
Rejected: true,
// Not specifying message is the same as expecting the default message.
},
},
{
Username: "rejectMeWithDefaultMessage",
Expects: configv1alpha1.FederationDomainTransformsExampleExpects{
Rejected: true,
Message: "Authentication was rejected by a configured policy", // this is the default message
},
},
},
},
},
{
DisplayName: "name2",
ObjectRef: corev1.TypedLocalObjectReference{
APIGroup: pointer.String(apiGroupSupervisor),
Kind: "LDAPIdentityProvider",
Name: ldapIdentityProvider.Name,
},
Transforms: configv1alpha1.FederationDomainTransforms{
Expressions: []configv1alpha1.FederationDomainTransformsExpression{
{Type: "username/v1", Expression: `"pre:" + username`},
},
Examples: []configv1alpha1.FederationDomainTransformsExample{
{
Username: "ryan",
Groups: []string{"a", "b"},
Expects: configv1alpha1.FederationDomainTransformsExampleExpects{
Username: "pre:ryan",
Groups: []string{"b", "a"},
Rejected: false,
},
},
},
},
},
},
},
},
},
wantFDIssuers: []*federationdomainproviders.FederationDomainIssuer{
federationDomainIssuerWithIDPs(t, "https://issuer1.com", []*federationdomainproviders.FederationDomainIdentityProvider{
{
DisplayName: "name1",
UID: oidcIdentityProvider.UID,
Transforms: newTransformationPipeline(t, &celtransformer.TransformationConstants{
StringConstants: map[string]string{"str": "abc"},
StringListConstants: map[string][]string{"strL": {"def"}},
},
&celtransformer.AllowAuthenticationPolicy{
Expression: `username == "ryan" || username == "rejectMeWithDefaultMessage"`,
RejectedAuthenticationMessage: "only ryan allowed",
},
&celtransformer.AllowAuthenticationPolicy{Expression: `username != "rejectMeWithDefaultMessage"`},
&celtransformer.UsernameTransformation{Expression: `"pre:" + username`},
&celtransformer.GroupsTransformation{Expression: `groups.map(g, "pre:" + g)`},
),
},
{
DisplayName: "name2",
UID: ldapIdentityProvider.UID,
Transforms: newTransformationPipeline(t, &celtransformer.TransformationConstants{},
&celtransformer.UsernameTransformation{Expression: `"pre:" + username`},
),
},
}),
},
wantStatusUpdates: []*configv1alpha1.FederationDomain{
expectedFederationDomainStatusUpdate(
&configv1alpha1.FederationDomain{
ObjectMeta: metav1.ObjectMeta{Name: "config1", Namespace: namespace, Generation: 123},
},
configv1alpha1.FederationDomainPhaseReady,
allHappyConditionsSuccess("https://issuer1.com", frozenMetav1Now, 123),
),
},
},
{ {
name: "the federation domain specifies illegal const type, which shouldn't really happen since the CRD validates it", name: "the federation domain specifies illegal const type, which shouldn't really happen since the CRD validates it",
inputObjects: []runtime.Object{ inputObjects: []runtime.Object{
@ -1719,7 +1848,12 @@ func TestTestFederationDomainWatcherControllerSync(t *testing.T) {
if tt.wantFDIssuers != nil { if tt.wantFDIssuers != nil {
require.True(t, federationDomainsSetter.SetFederationDomainsWasCalled) require.True(t, federationDomainsSetter.SetFederationDomainsWasCalled)
require.ElementsMatch(t, tt.wantFDIssuers, federationDomainsSetter.FederationDomainsReceived) // This is ugly, but we cannot test equality on compiled identity transformations because cel.Program
// cannot be compared for equality. This converts them to a type which can be tested for equality,
// which should be good enough for the purposes of this test.
require.ElementsMatch(t,
convertToComparableType(tt.wantFDIssuers),
convertToComparableType(federationDomainsSetter.FederationDomainsReceived))
} else { } else {
require.False(t, federationDomainsSetter.SetFederationDomainsWasCalled) require.False(t, federationDomainsSetter.SetFederationDomainsWasCalled)
} }
@ -1743,6 +1877,46 @@ func TestTestFederationDomainWatcherControllerSync(t *testing.T) {
} }
} }
type comparableFederationDomainIssuer struct {
issuer string
identityProviders []*comparableFederationDomainIdentityProvider
defaultIdentityProvider *comparableFederationDomainIdentityProvider
}
type comparableFederationDomainIdentityProvider struct {
DisplayName string
UID types.UID
TransformsSource []interface{}
}
func makeFederationDomainIdentityProviderComparable(fdi *federationdomainproviders.FederationDomainIdentityProvider) *comparableFederationDomainIdentityProvider {
if fdi == nil {
return nil
}
return &comparableFederationDomainIdentityProvider{
DisplayName: fdi.DisplayName,
UID: fdi.UID,
TransformsSource: fdi.Transforms.Source(),
}
}
func convertToComparableType(fdis []*federationdomainproviders.FederationDomainIssuer) []*comparableFederationDomainIssuer {
result := []*comparableFederationDomainIssuer{}
for _, fdi := range fdis {
comparableFDIs := make([]*comparableFederationDomainIdentityProvider, len(fdi.IdentityProviders()))
for _, idp := range fdi.IdentityProviders() {
comparableFDIs = append(comparableFDIs, makeFederationDomainIdentityProviderComparable(idp))
}
converted := &comparableFederationDomainIssuer{
issuer: fdi.Issuer(),
identityProviders: comparableFDIs,
defaultIdentityProvider: makeFederationDomainIdentityProviderComparable(fdi.DefaultIdentityProvider()),
}
result = append(result, converted)
}
return result
}
func expectedFederationDomainStatusUpdate( func expectedFederationDomainStatusUpdate(
fd *configv1alpha1.FederationDomain, fd *configv1alpha1.FederationDomain,
phase configv1alpha1.FederationDomainPhase, phase configv1alpha1.FederationDomainPhase,
@ -1789,3 +1963,116 @@ func sortFederationDomainsByName(federationDomains []*configv1alpha1.FederationD
return federationDomains[a].GetName() < federationDomains[b].GetName() return federationDomains[a].GetName() < federationDomains[b].GetName()
}) })
} }
func newTransformationPipeline(
t *testing.T,
consts *celtransformer.TransformationConstants,
transformations ...celtransformer.CELTransformation,
) *idtransform.TransformationPipeline {
pipeline := idtransform.NewTransformationPipeline()
compiler, err := celtransformer.NewCELTransformer(celTransformerMaxExpressionRuntime)
require.NoError(t, err)
if consts.StringConstants == nil {
consts.StringConstants = map[string]string{}
}
if consts.StringListConstants == nil {
consts.StringListConstants = map[string][]string{}
}
for _, transform := range transformations {
compiledTransform, err := compiler.CompileTransformation(transform, consts)
require.NoError(t, err)
pipeline.AppendTransformation(compiledTransform)
}
return pipeline
}
func TestTransformationPipelinesCanBeTestedForEqualityUsingSourceToMakeTestingEasier(t *testing.T) {
compiler, err := celtransformer.NewCELTransformer(5 * time.Second)
require.NoError(t, err)
transforms := []celtransformer.CELTransformation{
&celtransformer.AllowAuthenticationPolicy{
Expression: `username == "ryan" || username == "rejectMeWithDefaultMessage"`,
RejectedAuthenticationMessage: "only ryan allowed",
},
&celtransformer.UsernameTransformation{Expression: `"pre:" + username`},
&celtransformer.GroupsTransformation{Expression: `groups.map(g, "pre:" + g)`},
}
differentTransforms := []celtransformer.CELTransformation{
&celtransformer.AllowAuthenticationPolicy{
Expression: `username == "ryan" || username == "different"`,
RejectedAuthenticationMessage: "different",
},
&celtransformer.UsernameTransformation{Expression: `"different:" + username`},
&celtransformer.GroupsTransformation{Expression: `groups.map(g, "different:" + g)`},
}
consts := &celtransformer.TransformationConstants{
StringConstants: map[string]string{
"foo": "bar",
"baz": "bat",
},
StringListConstants: map[string][]string{
"foo": {"a", "b"},
"bar": {"c", "d"},
},
}
differentConsts := &celtransformer.TransformationConstants{
StringConstants: map[string]string{
"foo": "barDifferent",
"baz": "bat",
},
StringListConstants: map[string][]string{
"foo": {"aDifferent", "b"},
"bar": {"c", "d"},
},
}
pipeline := idtransform.NewTransformationPipeline()
equalPipeline := idtransform.NewTransformationPipeline()
differentPipeline1 := idtransform.NewTransformationPipeline()
differentPipeline2 := idtransform.NewTransformationPipeline()
expectedSourceList := []interface{}{}
for i, transform := range transforms {
// Compile and append to a pipeline.
compiledTransform1, err := compiler.CompileTransformation(transform, consts)
require.NoError(t, err)
pipeline.AppendTransformation(compiledTransform1)
// Recompile the same thing and append it to another pipeline.
// This pipeline should end up being equal to the first one.
compiledTransform2, err := compiler.CompileTransformation(transform, consts)
require.NoError(t, err)
equalPipeline.AppendTransformation(compiledTransform2)
// Build up a test expectation value.
expectedSourceList = append(expectedSourceList, &celtransformer.CELTransformationSource{Expr: transform, Consts: consts})
// Compile a different expression using the same constants and append it to a different pipeline.
// This should not be equal to the other pipelines.
compiledDifferentExpressionSameConsts, err := compiler.CompileTransformation(differentTransforms[i], consts)
require.NoError(t, err)
differentPipeline1.AppendTransformation(compiledDifferentExpressionSameConsts)
// Compile the same expression using the different constants and append it to a different pipeline.
// This should not be equal to the other pipelines.
compiledSameExpressionDifferentConsts, err := compiler.CompileTransformation(transform, differentConsts)
require.NoError(t, err)
differentPipeline2.AppendTransformation(compiledSameExpressionDifferentConsts)
}
require.Equal(t, expectedSourceList, pipeline.Source())
require.Equal(t, expectedSourceList, equalPipeline.Source())
// The source of compiled pipelines can be compared to each other in this way for testing purposes.
require.Equal(t, pipeline.Source(), equalPipeline.Source())
require.NotEqual(t, pipeline.Source(), differentPipeline1.Source())
require.NotEqual(t, pipeline.Source(), differentPipeline2.Source())
}

View File

@ -25,6 +25,10 @@ type TransformationResult struct {
// IdentityTransformation is an individual identity transformation which can be evaluated. // IdentityTransformation is an individual identity transformation which can be evaluated.
type IdentityTransformation interface { type IdentityTransformation interface {
Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error) Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error)
// Source returns some representation of the original source code of the transformation, which is
// useful for tests to be able to check that a compiled transformation came from the right source.
Source() interface{}
} }
// TransformationPipeline is a list of identity transforms, which can be evaluated in order against some given input // TransformationPipeline is a list of identity transforms, which can be evaluated in order against some given input
@ -85,6 +89,14 @@ func (p *TransformationPipeline) Evaluate(ctx context.Context, username string,
return accumulatedResult, nil return accumulatedResult, nil
} }
func (p *TransformationPipeline) Source() []interface{} {
result := []interface{}{}
for _, transform := range p.transforms {
result = append(result, transform.Source())
}
return result
}
func sortAndUniq(s []string) []string { func sortAndUniq(s []string) []string {
unique := sets.New(s...).UnsortedList() unique := sets.New(s...).UnsortedList()
sort.Strings(unique) sort.Strings(unique)

View File

@ -11,9 +11,9 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
type FakeNoopTransformer struct{} type fakeNoopTransformer struct{}
func (a FakeNoopTransformer) Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error) { func (a fakeNoopTransformer) Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error) {
return &TransformationResult{ return &TransformationResult{
Username: username, Username: username,
Groups: groups, Groups: groups,
@ -22,9 +22,13 @@ func (a FakeNoopTransformer) Evaluate(ctx context.Context, username string, grou
}, nil }, nil
} }
type FakeNilGroupTransformer struct{} func (a fakeNoopTransformer) Source() interface{} {
return nil // not needed for this test
}
func (a FakeNilGroupTransformer) Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error) { type fakeNilGroupTransformer struct{}
func (a fakeNilGroupTransformer) Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error) {
return &TransformationResult{ return &TransformationResult{
Username: username, Username: username,
Groups: nil, Groups: nil,
@ -33,9 +37,13 @@ func (a FakeNilGroupTransformer) Evaluate(ctx context.Context, username string,
}, nil }, nil
} }
type FakeAppendStringTransformer struct{} func (a fakeNilGroupTransformer) Source() interface{} {
return nil // not needed for this test
}
func (a FakeAppendStringTransformer) Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error) { type fakeAppendStringTransformer struct{}
func (a fakeAppendStringTransformer) Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error) {
newGroups := []string{} newGroups := []string{}
for _, group := range groups { for _, group := range groups {
newGroups = append(newGroups, group+":transformed") newGroups = append(newGroups, group+":transformed")
@ -48,9 +56,13 @@ func (a FakeAppendStringTransformer) Evaluate(ctx context.Context, username stri
}, nil }, nil
} }
type FakeDeleteUsernameAndGroupsTransformer struct{} func (a fakeAppendStringTransformer) Source() interface{} {
return nil // not needed for this test
}
func (d FakeDeleteUsernameAndGroupsTransformer) Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error) { type fakeDeleteUsernameAndGroupsTransformer struct{}
func (a fakeDeleteUsernameAndGroupsTransformer) Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error) {
return &TransformationResult{ return &TransformationResult{
Username: "", Username: "",
Groups: []string{}, Groups: []string{},
@ -59,9 +71,13 @@ func (d FakeDeleteUsernameAndGroupsTransformer) Evaluate(ctx context.Context, us
}, nil }, nil
} }
type FakeAuthenticationDisallowedTransformer struct{} func (a fakeDeleteUsernameAndGroupsTransformer) Source() interface{} {
return nil // not needed for this test
}
func (d FakeAuthenticationDisallowedTransformer) Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error) { type fakeAuthenticationDisallowedTransformer struct{}
func (a fakeAuthenticationDisallowedTransformer) Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error) {
newGroups := []string{} newGroups := []string{}
for _, group := range groups { for _, group := range groups {
newGroups = append(newGroups, group+":disallowed") newGroups = append(newGroups, group+":disallowed")
@ -74,13 +90,33 @@ func (d FakeAuthenticationDisallowedTransformer) Evaluate(ctx context.Context, u
}, nil }, nil
} }
type FakeErrorTransformer struct{} func (a fakeAuthenticationDisallowedTransformer) Source() interface{} {
return nil // not needed for this test
}
func (d FakeErrorTransformer) Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error) { type fakeErrorTransformer struct{}
func (a fakeErrorTransformer) Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error) {
return &TransformationResult{}, errors.New("unexpected catastrophic error") return &TransformationResult{}, errors.New("unexpected catastrophic error")
} }
func TestTransformationPipeline(t *testing.T) { func (a fakeErrorTransformer) Source() interface{} {
return nil // not needed for this test
}
type fakeTransformerWithSource struct {
source string
}
func (a fakeTransformerWithSource) Evaluate(ctx context.Context, username string, groups []string) (*TransformationResult, error) {
return nil, nil // not needed for this test
}
func (a fakeTransformerWithSource) Source() interface{} {
return a.source
}
func TestTransformationPipelineEvaluation(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
username string username string
@ -95,7 +131,7 @@ func TestTransformationPipeline(t *testing.T) {
{ {
name: "single transformation applied successfully", name: "single transformation applied successfully",
transforms: []IdentityTransformation{ transforms: []IdentityTransformation{
FakeAppendStringTransformer{}, fakeAppendStringTransformer{},
}, },
username: "foo", username: "foo",
groups: []string{ groups: []string{
@ -113,7 +149,7 @@ func TestTransformationPipeline(t *testing.T) {
{ {
name: "group results are sorted and made unique", name: "group results are sorted and made unique",
transforms: []IdentityTransformation{ transforms: []IdentityTransformation{
FakeAppendStringTransformer{}, fakeAppendStringTransformer{},
}, },
username: "foo", username: "foo",
groups: []string{ groups: []string{
@ -141,8 +177,8 @@ func TestTransformationPipeline(t *testing.T) {
"foobaz", "foobaz",
}, },
transforms: []IdentityTransformation{ transforms: []IdentityTransformation{
FakeAppendStringTransformer{}, fakeAppendStringTransformer{},
FakeAppendStringTransformer{}, fakeAppendStringTransformer{},
}, },
wantUsername: "foo:transformed:transformed", wantUsername: "foo:transformed:transformed",
wantGroups: []string{ wantGroups: []string{
@ -159,7 +195,7 @@ func TestTransformationPipeline(t *testing.T) {
"foobar", "foobar",
}, },
transforms: []IdentityTransformation{ transforms: []IdentityTransformation{
FakeAuthenticationDisallowedTransformer{}, fakeAuthenticationDisallowedTransformer{},
}, },
wantUsername: "foo:disallowed", wantUsername: "foo:disallowed",
wantGroups: []string{"foobar:disallowed"}, wantGroups: []string{"foobar:disallowed"},
@ -173,10 +209,10 @@ func TestTransformationPipeline(t *testing.T) {
"foobar", "foobar",
}, },
transforms: []IdentityTransformation{ transforms: []IdentityTransformation{
FakeAppendStringTransformer{}, fakeAppendStringTransformer{},
FakeAuthenticationDisallowedTransformer{}, fakeAuthenticationDisallowedTransformer{},
// this transformation will not be run because the previous exits the pipeline // this transformation will not be run because the previous exits the pipeline
FakeAppendStringTransformer{}, fakeAppendStringTransformer{},
}, },
wantUsername: "foo:transformed:disallowed", wantUsername: "foo:transformed:disallowed",
wantGroups: []string{"foobar:transformed:disallowed"}, wantGroups: []string{"foobar:transformed:disallowed"},
@ -190,9 +226,9 @@ func TestTransformationPipeline(t *testing.T) {
"foobar", "foobar",
}, },
transforms: []IdentityTransformation{ transforms: []IdentityTransformation{
FakeAppendStringTransformer{}, fakeAppendStringTransformer{},
FakeErrorTransformer{}, fakeErrorTransformer{},
FakeAppendStringTransformer{}, fakeAppendStringTransformer{},
}, },
wantError: "identity transformation at index 1: unexpected catastrophic error", wantError: "identity transformation at index 1: unexpected catastrophic error",
}, },
@ -200,7 +236,7 @@ func TestTransformationPipeline(t *testing.T) {
name: "empty username not allowed", name: "empty username not allowed",
username: "foo", username: "foo",
transforms: []IdentityTransformation{ transforms: []IdentityTransformation{
FakeDeleteUsernameAndGroupsTransformer{}, fakeDeleteUsernameAndGroupsTransformer{},
}, },
wantError: "identity transformation returned an empty username, which is not allowed", wantError: "identity transformation returned an empty username, which is not allowed",
}, },
@ -208,7 +244,7 @@ func TestTransformationPipeline(t *testing.T) {
name: "whitespace username not allowed", name: "whitespace username not allowed",
username: " \t\n\r ", username: " \t\n\r ",
transforms: []IdentityTransformation{ transforms: []IdentityTransformation{
FakeNoopTransformer{}, fakeNoopTransformer{},
}, },
wantError: "identity transformation returned an empty username, which is not allowed", wantError: "identity transformation returned an empty username, which is not allowed",
}, },
@ -217,7 +253,7 @@ func TestTransformationPipeline(t *testing.T) {
username: "foo", username: "foo",
groups: []string{}, groups: []string{},
transforms: []IdentityTransformation{ transforms: []IdentityTransformation{
FakeAppendStringTransformer{}, fakeAppendStringTransformer{},
}, },
wantUsername: "foo:transformed", wantUsername: "foo:transformed",
wantGroups: []string{}, wantGroups: []string{},
@ -229,7 +265,7 @@ func TestTransformationPipeline(t *testing.T) {
username: "foo", username: "foo",
groups: nil, groups: nil,
transforms: []IdentityTransformation{ transforms: []IdentityTransformation{
FakeNoopTransformer{}, fakeNoopTransformer{},
}, },
wantUsername: "foo", wantUsername: "foo",
wantGroups: []string{}, wantGroups: []string{},
@ -243,7 +279,7 @@ func TestTransformationPipeline(t *testing.T) {
"these.will.be.converted.to.nil", "these.will.be.converted.to.nil",
}, },
transforms: []IdentityTransformation{ transforms: []IdentityTransformation{
FakeNilGroupTransformer{}, fakeNilGroupTransformer{},
}, },
wantError: "identity transformation returned a null list of groups, which is not allowed", wantError: "identity transformation returned a null list of groups, which is not allowed",
}, },
@ -287,3 +323,18 @@ func TestTransformationPipeline(t *testing.T) {
}) })
} }
} }
func TestTransformationSource(t *testing.T) {
pipeline := NewTransformationPipeline()
for _, transform := range []IdentityTransformation{
&fakeTransformerWithSource{source: "foo"},
&fakeTransformerWithSource{source: "bar"},
&fakeTransformerWithSource{source: "baz"},
} {
pipeline.AppendTransformation(transform)
}
require.Equal(t, []interface{}{"foo", "bar", "baz"}, pipeline.Source())
require.NotEqual(t, []interface{}{"foo", "something-else", "baz"}, pipeline.Source())
}