From e86488615ad842727122b7386ebf5ffdb5887c67 Mon Sep 17 00:00:00 2001 From: Monis Khan Date: Tue, 28 Sep 2021 11:29:20 -0400 Subject: [PATCH] upstreamoidc: directly detect user info support Avoid reliance on an error string from the Core OS OIDC lib. Signed-off-by: Monis Khan --- internal/upstreamoidc/upstreamoidc.go | 20 +++++++---- internal/upstreamoidc/upstreamoidc_test.go | 41 +++++++++++++++++----- 2 files changed, 46 insertions(+), 15 deletions(-) diff --git a/internal/upstreamoidc/upstreamoidc.go b/internal/upstreamoidc/upstreamoidc.go index c8d3722f..659246ab 100644 --- a/internal/upstreamoidc/upstreamoidc.go +++ b/internal/upstreamoidc/upstreamoidc.go @@ -37,6 +37,7 @@ type ProviderConfig struct { AllowPasswordGrant bool Provider interface { Verifier(*coreosoidc.Config) *coreosoidc.IDTokenVerifier + Claims(v interface{}) error UserInfo(ctx context.Context, tokenSource oauth2.TokenSource) (*coreosoidc.UserInfo, error) } } @@ -160,14 +161,21 @@ func (p *ProviderConfig) fetchUserInfo(ctx context.Context, tok *oauth2.Token, c return nil // defer to existing ID token validation } + providerJSON := &struct { + UserInfoURL string `json:"userinfo_endpoint"` + }{} + if err := p.Provider.Claims(providerJSON); err != nil { + // this should never happen because we should have already parsed these claims at an earlier stage + return httperr.Wrap(http.StatusInternalServerError, "could not unmarshal discovery JSON", err) + } + + // implementing the user info endpoint is not required, skip this logic when it is absent + if len(providerJSON.UserInfoURL) == 0 { + return nil + } + userInfo, err := p.Provider.UserInfo(coreosoidc.ClientContext(ctx, p.Client), oauth2.StaticTokenSource(tok)) if err != nil { - // the user info endpoint is not required but we do not have a good way to probe if it was provided - const userInfoUnsupported = "oidc: user info endpoint is not supported by this provider" - if err.Error() == userInfoUnsupported { - return nil - } - return httperr.Wrap(http.StatusInternalServerError, "could not get user info", err) } diff --git a/internal/upstreamoidc/upstreamoidc_test.go b/internal/upstreamoidc/upstreamoidc_test.go index 801207bf..2ffd40ca 100644 --- a/internal/upstreamoidc/upstreamoidc_test.go +++ b/internal/upstreamoidc/upstreamoidc_test.go @@ -70,9 +70,6 @@ func TestProviderConfig(t *testing.T) { validIDToken = "eyJhbGciOiJSUzI1NiIsImtpZCI6InRlc3Qta2lkIiwidHlwIjoiSldUIn0.eyJhdWQiOiJ0ZXN0LWNsaWVudC1pZCIsImJhdCI6ImJheiIsImZvbyI6ImJhciIsImlhdCI6MTYwNjc2ODU5MywianRpIjoidGVzdC1qdGkiLCJuYmYiOjE2MDY3Njg1OTMsInN1YiI6InRlc3QtdXNlciJ9.DuqVZ7pGhHqKz7gNr4j2W1s1N8YrSltktH4wW19L4oD1OE2-O72jAnNj5xdjilsa8l7h9ox-5sMF0Tkh3BdRlHQK9dEtNm9tW-JreUnWJ3LCqUs-LZp4NG7edvq2sH_1Bn7O2_NQV51s8Pl04F60CndjQ4NM-6WkqDQTKyY6vJXU7idvM-6TM2HJZK-Na88cOJ9KIK37tL5DhcbsHVF47Dq8uPZ0KbjNQjJLAIi_1GeQBgc6yJhDUwRY4Xu6S0dtTHA6xTI8oSXoamt4bkViEHfJBp97LZQiNz8mku5pVc0aNwP1p4hMHxRHhLXrJjbh-Hx4YFjxtOnIq9t1mHlD4A" //nolint: gosec ) - // if the error string for unsupported user info changes, this will hopefully catch it - _, userInfoNotSupported := (&oidc.Provider{}).UserInfo(context.Background(), nil) - t.Run("PasswordCredentialsGrantAndValidateTokens", func(t *testing.T) { tests := []struct { name string @@ -82,6 +79,7 @@ func TestProviderConfig(t *testing.T) { wantErr string wantToken oidctypes.Token + rawClaims []byte userInfo *oidc.UserInfo userInfoErr error wantUserInfoCalled bool @@ -111,8 +109,8 @@ func TestProviderConfig(t *testing.T) { }, }, }, - userInfoErr: userInfoNotSupported, - wantUserInfoCalled: true, + rawClaims: []byte(`{}`), // user info not supported + wantUserInfoCalled: false, }, { name: "valid with userinfo", @@ -245,6 +243,11 @@ func TestProviderConfig(t *testing.T) { })) t.Cleanup(tokenServer.Close) + rawClaims := tt.rawClaims + if len(rawClaims) == 0 && (tt.userInfo != nil || tt.userInfoErr != nil) { + rawClaims = []byte(`{"userinfo_endpoint": "not-empty"}`) + } + p := ProviderConfig{ Name: "test-name", UsernameClaim: "test-username-claim", @@ -260,6 +263,7 @@ func TestProviderConfig(t *testing.T) { Scopes: []string{"scope1", "scope2"}, }, Provider: &mockProvider{ + rawClaims: rawClaims, userInfo: tt.userInfo, userInfoErr: tt.userInfoErr, }, @@ -293,6 +297,7 @@ func TestProviderConfig(t *testing.T) { wantErr string wantToken oidctypes.Token + rawClaims []byte userInfo *oidc.UserInfo userInfoErr error wantUserInfoCalled bool @@ -352,8 +357,8 @@ func TestProviderConfig(t *testing.T) { }, }, }, - userInfoErr: userInfoNotSupported, - wantUserInfoCalled: true, + rawClaims: []byte(`{}`), // user info not supported + wantUserInfoCalled: false, }, { name: "valid", @@ -381,8 +386,15 @@ func TestProviderConfig(t *testing.T) { }, }, }, - userInfoErr: userInfoNotSupported, - wantUserInfoCalled: true, + rawClaims: []byte(`{}`), // user info not supported + wantUserInfoCalled: false, + }, + { + name: "user info discovery parse error", + authCode: "valid", + returnIDTok: validIDToken, + rawClaims: []byte(`junk`), // user info discovery fails + wantErr: "could not fetch user info claims: could not unmarshal discovery JSON: invalid character 'j' looking for beginning of value", }, { name: "user info fetch error", @@ -496,6 +508,11 @@ func TestProviderConfig(t *testing.T) { })) t.Cleanup(tokenServer.Close) + rawClaims := tt.rawClaims + if len(rawClaims) == 0 && (tt.userInfo != nil || tt.userInfoErr != nil) { + rawClaims = []byte(`{"userinfo_endpoint": "not-empty"}`) + } + p := ProviderConfig{ Name: "test-name", UsernameClaim: "test-username-claim", @@ -511,6 +528,7 @@ func TestProviderConfig(t *testing.T) { Scopes: []string{"scope1", "scope2"}, }, Provider: &mockProvider{ + rawClaims: rawClaims, userInfo: tt.userInfo, userInfoErr: tt.userInfoErr, }, @@ -559,12 +577,17 @@ func mockVerifier() *oidc.IDTokenVerifier { type mockProvider struct { called bool + rawClaims []byte userInfo *oidc.UserInfo userInfoErr error } func (m *mockProvider) Verifier(_ *oidc.Config) *oidc.IDTokenVerifier { return mockVerifier() } +func (m *mockProvider) Claims(v interface{}) error { + return json.Unmarshal(m.rawClaims, v) +} + func (m *mockProvider) UserInfo(_ context.Context, tokenSource oauth2.TokenSource) (*oidc.UserInfo, error) { m.called = true