diff --git a/internal/concierge/impersonator/impersonator.go b/internal/concierge/impersonator/impersonator.go index 2bd01574..9dda0189 100644 --- a/internal/concierge/impersonator/impersonator.go +++ b/internal/concierge/impersonator/impersonator.go @@ -4,6 +4,7 @@ package impersonator import ( + "context" "encoding/base64" "fmt" "net/http" @@ -144,18 +145,49 @@ func ensureNoImpersonationHeaders(r *http.Request) error { return nil } +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + func getProxyHeaders(userInfo user.Info, requestHeaders http.Header) http.Header { newHeaders := http.Header{} - newHeaders.Set("Impersonate-User", userInfo.GetName()) - for _, group := range userInfo.GetGroups() { - newHeaders.Add("Impersonate-Group", group) + + // Leverage client-go's impersonation RoundTripper to set impersonation headers for us in the new + // request. The client-go RoundTripper not only sets all of the impersonation headers for us, but + // it also does some helpful escaping of characters that can't go into an HTTP header. To do this, + // we make a fake call to the impersonation RoundTripper with a fake HTTP request and a delegate + // RoundTripper that captures the impersonation headers set on the request. + impersonateConfig := transport.ImpersonationConfig{ + UserName: userInfo.GetName(), + Groups: userInfo.GetGroups(), + Extra: userInfo.GetExtra(), } + impersonateHeaderSpy := roundTripperFunc(func(r *http.Request) (*http.Response, error) { + newHeaders.Set(transport.ImpersonateUserHeader, r.Header.Get(transport.ImpersonateUserHeader)) + for _, groupHeaderValue := range r.Header.Values(transport.ImpersonateGroupHeader) { + newHeaders.Add(transport.ImpersonateGroupHeader, groupHeaderValue) + } + for headerKey, headerValues := range r.Header { + if strings.HasPrefix(headerKey, transport.ImpersonateUserExtraHeaderPrefix) { + for _, headerValue := range headerValues { + newHeaders.Add(headerKey, headerValue) + } + } + } + return nil, nil + }) + fakeReq, _ := http.NewRequestWithContext(context.Background(), "", "", nil) + //nolint:bodyclose // We return a nil http.Response above, so there is nothing to close. + _, _ = transport.NewImpersonatingRoundTripper(impersonateConfig, impersonateHeaderSpy).RoundTrip(fakeReq) + + // Copy over the allowed header values from the original request to the new request. for _, header := range allowedHeaders { values := requestHeaders.Values(header) for i := range values { newHeaders.Add(header, values[i]) } } + return newHeaders } diff --git a/internal/concierge/impersonator/impersonator_test.go b/internal/concierge/impersonator/impersonator_test.go index e993538c..8a279a7b 100644 --- a/internal/concierge/impersonator/impersonator_test.go +++ b/internal/concierge/impersonator/impersonator_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "reflect" "testing" "github.com/golang/mock/gomock" @@ -21,6 +22,7 @@ import ( "k8s.io/apiserver/pkg/authentication/user" "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd/api" + "k8s.io/client-go/transport" authenticationv1alpha1 "go.pinniped.dev/generated/1.20/apis/concierge/authentication/v1alpha1" "go.pinniped.dev/generated/1.20/apis/concierge/login" @@ -37,8 +39,20 @@ func TestImpersonator(t *testing.T) { const ( defaultAPIGroup = "pinniped.dev" customAPIGroup = "walrus.tld" + + testUser = "test-user" ) + testGroups := []string{"test-group-1", "test-group-2"} + testExtra := map[string][]string{ + "extra-1": {"some", "extra", "stuff"}, + "extra-2": {"some", "more", "extra", "stuff"}, + } + testExtraHeaders := map[string]string{ + "extra-1": transport.ImpersonateUserExtraHeaderPrefix + "extra-1", + "extra-2": transport.ImpersonateUserExtraHeaderPrefix + "extra-2", + } + validURL, _ := url.Parse("http://pinniped.dev/blah") testServerCA, testServerURL := testutil.TLSTestServer(t, func(w http.ResponseWriter, r *http.Request) { // Expect that the request is authenticated based on the kubeconfig credential. @@ -56,6 +70,25 @@ func TestImpersonator(t *testing.T) { http.Error(w, "got unexpected user agent header", http.StatusBadRequest) return } + // Ensure impersonation headers are set. + if values := r.Header.Values(transport.ImpersonateUserHeader); len(values) != 1 || values[0] != testUser { + message := fmt.Sprintf("got unexpected %q header: %q", transport.ImpersonateUserHeader, values) + http.Error(w, message, http.StatusBadRequest) + return + } + if values := r.Header.Values(transport.ImpersonateGroupHeader); !reflect.DeepEqual(testGroups, values) { + message := fmt.Sprintf("got unexpected %q headers: %q", transport.ImpersonateGroupHeader, values) + http.Error(w, message, http.StatusBadRequest) + return + } + for testExtraKey, testExtraValues := range testExtra { + header := testExtraHeaders[testExtraKey] + if values := r.Header.Values(header); !reflect.DeepEqual(testExtraValues, values) { + message := fmt.Sprintf("got unexpected %q headers: %q", header, values) + http.Error(w, message, http.StatusBadRequest) + return + } + } _, _ = w.Write([]byte("successful proxied response")) }) testServerKubeconfig := rest.Config{ @@ -230,9 +263,10 @@ func TestImpersonator(t *testing.T) { }), expectMockToken: func(t *testing.T, recorder *mocktokenauthenticator.MockTokenMockRecorder) { userInfo := user.DefaultInfo{ - Name: "test-user", - Groups: []string{"test-group-1", "test-group-2"}, + Name: testUser, + Groups: testGroups, UID: "test-uid", + Extra: testExtra, } response := &authenticator.Response{User: &userInfo} recorder.AuthenticateToken(gomock.Any(), "test-token").Return(response, true, nil) @@ -252,9 +286,10 @@ func TestImpersonator(t *testing.T) { }), expectMockToken: func(t *testing.T, recorder *mocktokenauthenticator.MockTokenMockRecorder) { userInfo := user.DefaultInfo{ - Name: "test-user", - Groups: []string{"test-group-1", "test-group-2"}, + Name: testUser, + Groups: testGroups, UID: "test-uid", + Extra: testExtra, } response := &authenticator.Response{User: &userInfo} recorder.AuthenticateToken(gomock.Any(), "test-token").Return(response, true, nil) @@ -306,7 +341,7 @@ func TestImpersonator(t *testing.T) { proxy.ServeHTTP(w, tt.request) require.Equal(t, requestBeforeServe, tt.request, "ServeHTTP() mutated the request, and it should not per http.Handler docs") if tt.wantHTTPStatus != 0 { - require.Equal(t, tt.wantHTTPStatus, w.Code) + require.Equalf(t, tt.wantHTTPStatus, w.Code, "fyi, response body was %q", w.Body.String()) } if tt.wantHTTPBody != "" { require.Equal(t, tt.wantHTTPBody, w.Body.String())