internal/concierge/impersonator: set user extra impersonation headers

Signed-off-by: Andrew Keesler <akeesler@vmware.com>
This commit is contained in:
Andrew Keesler 2021-02-16 09:09:54 -05:00
parent c7905c6638
commit eb19980110
No known key found for this signature in database
GPG Key ID: 27CE0444346F9413
2 changed files with 75 additions and 8 deletions

View File

@ -4,6 +4,7 @@
package impersonator package impersonator
import ( import (
"context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
@ -144,18 +145,49 @@ func ensureNoImpersonationHeaders(r *http.Request) error {
return nil return nil
} }
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
func getProxyHeaders(userInfo user.Info, requestHeaders http.Header) http.Header { func getProxyHeaders(userInfo user.Info, requestHeaders http.Header) http.Header {
newHeaders := http.Header{} newHeaders := http.Header{}
newHeaders.Set("Impersonate-User", userInfo.GetName())
for _, group := range userInfo.GetGroups() { // Leverage client-go's impersonation RoundTripper to set impersonation headers for us in the new
newHeaders.Add("Impersonate-Group", group) // request. The client-go RoundTripper not only sets all of the impersonation headers for us, but
// it also does some helpful escaping of characters that can't go into an HTTP header. To do this,
// we make a fake call to the impersonation RoundTripper with a fake HTTP request and a delegate
// RoundTripper that captures the impersonation headers set on the request.
impersonateConfig := transport.ImpersonationConfig{
UserName: userInfo.GetName(),
Groups: userInfo.GetGroups(),
Extra: userInfo.GetExtra(),
} }
impersonateHeaderSpy := roundTripperFunc(func(r *http.Request) (*http.Response, error) {
newHeaders.Set(transport.ImpersonateUserHeader, r.Header.Get(transport.ImpersonateUserHeader))
for _, groupHeaderValue := range r.Header.Values(transport.ImpersonateGroupHeader) {
newHeaders.Add(transport.ImpersonateGroupHeader, groupHeaderValue)
}
for headerKey, headerValues := range r.Header {
if strings.HasPrefix(headerKey, transport.ImpersonateUserExtraHeaderPrefix) {
for _, headerValue := range headerValues {
newHeaders.Add(headerKey, headerValue)
}
}
}
return nil, nil
})
fakeReq, _ := http.NewRequestWithContext(context.Background(), "", "", nil)
//nolint:bodyclose // We return a nil http.Response above, so there is nothing to close.
_, _ = transport.NewImpersonatingRoundTripper(impersonateConfig, impersonateHeaderSpy).RoundTrip(fakeReq)
// Copy over the allowed header values from the original request to the new request.
for _, header := range allowedHeaders { for _, header := range allowedHeaders {
values := requestHeaders.Values(header) values := requestHeaders.Values(header)
for i := range values { for i := range values {
newHeaders.Add(header, values[i]) newHeaders.Add(header, values[i])
} }
} }
return newHeaders return newHeaders
} }

View File

@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/url" "net/url"
"reflect"
"testing" "testing"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
@ -21,6 +22,7 @@ import (
"k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/authentication/user"
"k8s.io/client-go/rest" "k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd/api" "k8s.io/client-go/tools/clientcmd/api"
"k8s.io/client-go/transport"
authenticationv1alpha1 "go.pinniped.dev/generated/1.20/apis/concierge/authentication/v1alpha1" authenticationv1alpha1 "go.pinniped.dev/generated/1.20/apis/concierge/authentication/v1alpha1"
"go.pinniped.dev/generated/1.20/apis/concierge/login" "go.pinniped.dev/generated/1.20/apis/concierge/login"
@ -37,8 +39,20 @@ func TestImpersonator(t *testing.T) {
const ( const (
defaultAPIGroup = "pinniped.dev" defaultAPIGroup = "pinniped.dev"
customAPIGroup = "walrus.tld" customAPIGroup = "walrus.tld"
testUser = "test-user"
) )
testGroups := []string{"test-group-1", "test-group-2"}
testExtra := map[string][]string{
"extra-1": {"some", "extra", "stuff"},
"extra-2": {"some", "more", "extra", "stuff"},
}
testExtraHeaders := map[string]string{
"extra-1": transport.ImpersonateUserExtraHeaderPrefix + "extra-1",
"extra-2": transport.ImpersonateUserExtraHeaderPrefix + "extra-2",
}
validURL, _ := url.Parse("http://pinniped.dev/blah") validURL, _ := url.Parse("http://pinniped.dev/blah")
testServerCA, testServerURL := testutil.TLSTestServer(t, func(w http.ResponseWriter, r *http.Request) { testServerCA, testServerURL := testutil.TLSTestServer(t, func(w http.ResponseWriter, r *http.Request) {
// Expect that the request is authenticated based on the kubeconfig credential. // Expect that the request is authenticated based on the kubeconfig credential.
@ -56,6 +70,25 @@ func TestImpersonator(t *testing.T) {
http.Error(w, "got unexpected user agent header", http.StatusBadRequest) http.Error(w, "got unexpected user agent header", http.StatusBadRequest)
return return
} }
// Ensure impersonation headers are set.
if values := r.Header.Values(transport.ImpersonateUserHeader); len(values) != 1 || values[0] != testUser {
message := fmt.Sprintf("got unexpected %q header: %q", transport.ImpersonateUserHeader, values)
http.Error(w, message, http.StatusBadRequest)
return
}
if values := r.Header.Values(transport.ImpersonateGroupHeader); !reflect.DeepEqual(testGroups, values) {
message := fmt.Sprintf("got unexpected %q headers: %q", transport.ImpersonateGroupHeader, values)
http.Error(w, message, http.StatusBadRequest)
return
}
for testExtraKey, testExtraValues := range testExtra {
header := testExtraHeaders[testExtraKey]
if values := r.Header.Values(header); !reflect.DeepEqual(testExtraValues, values) {
message := fmt.Sprintf("got unexpected %q headers: %q", header, values)
http.Error(w, message, http.StatusBadRequest)
return
}
}
_, _ = w.Write([]byte("successful proxied response")) _, _ = w.Write([]byte("successful proxied response"))
}) })
testServerKubeconfig := rest.Config{ testServerKubeconfig := rest.Config{
@ -230,9 +263,10 @@ func TestImpersonator(t *testing.T) {
}), }),
expectMockToken: func(t *testing.T, recorder *mocktokenauthenticator.MockTokenMockRecorder) { expectMockToken: func(t *testing.T, recorder *mocktokenauthenticator.MockTokenMockRecorder) {
userInfo := user.DefaultInfo{ userInfo := user.DefaultInfo{
Name: "test-user", Name: testUser,
Groups: []string{"test-group-1", "test-group-2"}, Groups: testGroups,
UID: "test-uid", UID: "test-uid",
Extra: testExtra,
} }
response := &authenticator.Response{User: &userInfo} response := &authenticator.Response{User: &userInfo}
recorder.AuthenticateToken(gomock.Any(), "test-token").Return(response, true, nil) recorder.AuthenticateToken(gomock.Any(), "test-token").Return(response, true, nil)
@ -252,9 +286,10 @@ func TestImpersonator(t *testing.T) {
}), }),
expectMockToken: func(t *testing.T, recorder *mocktokenauthenticator.MockTokenMockRecorder) { expectMockToken: func(t *testing.T, recorder *mocktokenauthenticator.MockTokenMockRecorder) {
userInfo := user.DefaultInfo{ userInfo := user.DefaultInfo{
Name: "test-user", Name: testUser,
Groups: []string{"test-group-1", "test-group-2"}, Groups: testGroups,
UID: "test-uid", UID: "test-uid",
Extra: testExtra,
} }
response := &authenticator.Response{User: &userInfo} response := &authenticator.Response{User: &userInfo}
recorder.AuthenticateToken(gomock.Any(), "test-token").Return(response, true, nil) recorder.AuthenticateToken(gomock.Any(), "test-token").Return(response, true, nil)
@ -306,7 +341,7 @@ func TestImpersonator(t *testing.T) {
proxy.ServeHTTP(w, tt.request) proxy.ServeHTTP(w, tt.request)
require.Equal(t, requestBeforeServe, tt.request, "ServeHTTP() mutated the request, and it should not per http.Handler docs") require.Equal(t, requestBeforeServe, tt.request, "ServeHTTP() mutated the request, and it should not per http.Handler docs")
if tt.wantHTTPStatus != 0 { if tt.wantHTTPStatus != 0 {
require.Equal(t, tt.wantHTTPStatus, w.Code) require.Equalf(t, tt.wantHTTPStatus, w.Code, "fyi, response body was %q", w.Body.String())
} }
if tt.wantHTTPBody != "" { if tt.wantHTTPBody != "" {
require.Equal(t, tt.wantHTTPBody, w.Body.String()) require.Equal(t, tt.wantHTTPBody, w.Body.String())