internal/kubeclient: fix not found test and request body closing bug

- I realized that the hardcoded fakekubeapi 404 not found response was invalid,
  so we were getting a default error message. I fixed it so the tests follow a
  higher fidelity code path.
- I caved and added a test for making sure the request body was always closed,
  and believe it or not, we were double closing a body. I don't *think* this will
  matter in production, since client-go will pass us ioutil.NopReader()'s, but
  at least we know now.

Signed-off-by: Andrew Keesler <akeesler@vmware.com>
This commit is contained in:
Andrew Keesler 2021-02-03 08:19:34 -05:00
parent efe1fa89fe
commit 62c117421a
No known key found for this signature in database
GPG Key ID: 27CE0444346F9413
5 changed files with 104 additions and 62 deletions

View File

@ -51,13 +51,13 @@ func New(opts ...Option) (*Client, error) {
protoKubeConfig := createProtoKubeConfig(c.config) protoKubeConfig := createProtoKubeConfig(c.config)
// Connect to the core Kubernetes API. // Connect to the core Kubernetes API.
k8sClient, err := kubernetes.NewForConfig(configWithWrapper(protoKubeConfig, kubescheme.Scheme, kubescheme.Codecs, c.middlewares)) k8sClient, err := kubernetes.NewForConfig(configWithWrapper(protoKubeConfig, kubescheme.Scheme, kubescheme.Codecs, c.middlewares, c.transportWrapper))
if err != nil { if err != nil {
return nil, fmt.Errorf("could not initialize Kubernetes client: %w", err) return nil, fmt.Errorf("could not initialize Kubernetes client: %w", err)
} }
// Connect to the Kubernetes aggregation API. // Connect to the Kubernetes aggregation API.
aggregatorClient, err := aggregatorclient.NewForConfig(configWithWrapper(protoKubeConfig, aggregatorclientscheme.Scheme, aggregatorclientscheme.Codecs, c.middlewares)) aggregatorClient, err := aggregatorclient.NewForConfig(configWithWrapper(protoKubeConfig, aggregatorclientscheme.Scheme, aggregatorclientscheme.Codecs, c.middlewares, c.transportWrapper))
if err != nil { if err != nil {
return nil, fmt.Errorf("could not initialize aggregation client: %w", err) return nil, fmt.Errorf("could not initialize aggregation client: %w", err)
} }
@ -65,7 +65,7 @@ func New(opts ...Option) (*Client, error) {
// Connect to the pinniped concierge API. // Connect to the pinniped concierge API.
// We cannot use protobuf encoding here because we are using CRDs // We cannot use protobuf encoding here because we are using CRDs
// (for which protobuf encoding is not yet supported). // (for which protobuf encoding is not yet supported).
pinnipedConciergeClient, err := pinnipedconciergeclientset.NewForConfig(configWithWrapper(jsonKubeConfig, pinnipedconciergeclientsetscheme.Scheme, pinnipedconciergeclientsetscheme.Codecs, c.middlewares)) pinnipedConciergeClient, err := pinnipedconciergeclientset.NewForConfig(configWithWrapper(jsonKubeConfig, pinnipedconciergeclientsetscheme.Scheme, pinnipedconciergeclientsetscheme.Codecs, c.middlewares, c.transportWrapper))
if err != nil { if err != nil {
return nil, fmt.Errorf("could not initialize pinniped client: %w", err) return nil, fmt.Errorf("could not initialize pinniped client: %w", err)
} }
@ -73,7 +73,7 @@ func New(opts ...Option) (*Client, error) {
// Connect to the pinniped supervisor API. // Connect to the pinniped supervisor API.
// We cannot use protobuf encoding here because we are using CRDs // We cannot use protobuf encoding here because we are using CRDs
// (for which protobuf encoding is not yet supported). // (for which protobuf encoding is not yet supported).
pinnipedSupervisorClient, err := pinnipedsupervisorclientset.NewForConfig(configWithWrapper(jsonKubeConfig, pinnipedsupervisorclientsetscheme.Scheme, pinnipedsupervisorclientsetscheme.Codecs, c.middlewares)) pinnipedSupervisorClient, err := pinnipedsupervisorclientset.NewForConfig(configWithWrapper(jsonKubeConfig, pinnipedsupervisorclientsetscheme.Scheme, pinnipedsupervisorclientsetscheme.Codecs, c.middlewares, c.transportWrapper))
if err != nil { if err != nil {
return nil, fmt.Errorf("could not initialize pinniped client: %w", err) return nil, fmt.Errorf("could not initialize pinniped client: %w", err)
} }

View File

@ -66,12 +66,6 @@ var (
middlewareLabels = map[string]string{"some-label": "thing 2"} middlewareLabels = map[string]string{"some-label": "thing 2"}
) )
// TestKubeclient tests a subset of kubeclient functionality (from the public interface down). We
// intend for the following list of things to be tested with the integration tests:
// list (running in every informer cache)
// watch (running in every informer cache)
// discovery
// api errors
func TestKubeclient(t *testing.T) { func TestKubeclient(t *testing.T) {
// plog.ValidateAndSetLogLevelGlobally(plog.LevelDebug) // uncomment me to get some more debug logs // plog.ValidateAndSetLogLevelGlobally(plog.LevelDebug) // uncomment me to get some more debug logs
@ -109,7 +103,7 @@ func TestKubeclient(t *testing.T) {
CoreV1(). CoreV1().
Pods(pod.Namespace). Pods(pod.Namespace).
Get(context.Background(), "this-pod-does-not-exist", metav1.GetOptions{}) Get(context.Background(), "this-pod-does-not-exist", metav1.GetOptions{})
require.EqualError(t, err, "the server could not find the requested resource (get pods this-pod-does-not-exist)") require.EqualError(t, err, `couldn't find object for path "/api/v1/namespaces/good-namespace/pods/this-pod-does-not-exist"`)
// update // update
goodPodWithAnnotationsAndLabelsAndClusterName := with(goodPod, annotations(), labels(), clusterName()).(*corev1.Pod) goodPodWithAnnotationsAndLabelsAndClusterName := with(goodPod, annotations(), labels(), clusterName()).(*corev1.Pod)
@ -546,16 +540,15 @@ func TestKubeclient(t *testing.T) {
test.editRestConfig(t, restConfig) test.editRestConfig(t, restConfig)
} }
// our rt chain is:
// kubeclient -> wantCloseResp -> http.DefaultTransport -> wantCloseResp -> kubeclient
restConfig.Wrap(wantCloseRespWrapper(t))
var middlewares []*spyMiddleware var middlewares []*spyMiddleware
if test.middlewares != nil { if test.middlewares != nil {
middlewares = test.middlewares(t) middlewares = test.middlewares(t)
} }
opts := []Option{WithConfig(restConfig)} // our rt chain is:
// wantCloseReq -> kubeclient -> wantCloseResp -> http.DefaultTransport -> wantCloseResp -> kubeclient -> wantCloseReq
restConfig.Wrap(wantCloseRespWrapper(t))
opts := []Option{WithConfig(restConfig), WithTransportWrapper(wantCloseReqWrapper(t))}
for _, middleware := range middlewares { for _, middleware := range middlewares {
opts = append(opts, WithMiddleware(middleware)) opts = append(opts, WithMiddleware(middleware))
} }
@ -675,11 +668,13 @@ func newSimpleMiddleware(t *testing.T, hasMutateReqFunc, mutatedReq, hasMutateRe
type wantCloser struct { type wantCloser struct {
io.ReadCloser io.ReadCloser
closeCount int closeCount int
closeCalls []string
couldReadBytesJustBeforeClosing bool couldReadBytesJustBeforeClosing bool
} }
func (wc *wantCloser) Close() error { func (wc *wantCloser) Close() error {
wc.closeCount++ wc.closeCount++
wc.closeCalls = append(wc.closeCalls, getCaller())
n, _ := wc.ReadCloser.Read([]byte{0}) n, _ := wc.ReadCloser.Read([]byte{0})
if n > 0 { if n > 0 {
// there were still bytes left to be read // there were still bytes left to be read
@ -688,14 +683,53 @@ func (wc *wantCloser) Close() error {
return wc.ReadCloser.Close() return wc.ReadCloser.Close()
} }
// wantCloseRespWrapper returns a transport.WrapperFunc that validates that the http.Response func getCaller() string {
// returned by the underlying http.RoundTripper is closed properly. _, file, line, ok := runtime.Caller(2)
func wantCloseRespWrapper(t *testing.T) transport.WrapperFunc {
_, file, line, ok := runtime.Caller(1)
if !ok { if !ok {
file = "???" file = "???"
line = 0 line = 0
} }
return fmt.Sprintf("%s:%d", file, line)
}
// wantCloseReqWrapper returns a transport.WrapperFunc that validates that the http.Request
// passed to the underlying http.RoundTripper is closed properly.
func wantCloseReqWrapper(t *testing.T) transport.WrapperFunc {
caller := getCaller()
return func(rt http.RoundTripper) http.RoundTripper {
return roundTripperFunc(func(req *http.Request) (bool, *http.Response, error) {
if req.Body != nil {
wc := &wantCloser{ReadCloser: req.Body}
t.Cleanup(func() {
require.Equalf(t, wc.closeCount, 1, "did not close req body expected number of times at %s for req %#v; actual calls = %s", caller, req, wc.closeCalls)
})
req.Body = wc
}
if req.GetBody != nil {
originalBodyCopy, originalErr := req.GetBody()
req.GetBody = func() (io.ReadCloser, error) {
if originalErr != nil {
return nil, originalErr
}
wc := &wantCloser{ReadCloser: originalBodyCopy}
t.Cleanup(func() {
require.Equalf(t, wc.closeCount, 1, "did not close req body copy expected number of times at %s for req %#v; actual calls = %s", caller, req, wc.closeCalls)
})
return wc, nil
}
}
resp, err := rt.RoundTrip(req)
return false, resp, err
})
}
}
// wantCloseRespWrapper returns a transport.WrapperFunc that validates that the http.Response
// returned by the underlying http.RoundTripper is closed properly.
func wantCloseRespWrapper(t *testing.T) transport.WrapperFunc {
caller := getCaller()
return func(rt http.RoundTripper) http.RoundTripper { return func(rt http.RoundTripper) http.RoundTripper {
return roundTripperFunc(func(req *http.Request) (bool, *http.Response, error) { return roundTripperFunc(func(req *http.Request) (bool, *http.Response, error) {
resp, err := rt.RoundTrip(req) resp, err := rt.RoundTrip(req)
@ -705,8 +739,8 @@ func wantCloseRespWrapper(t *testing.T) transport.WrapperFunc {
} }
wc := &wantCloser{ReadCloser: resp.Body} wc := &wantCloser{ReadCloser: resp.Body}
t.Cleanup(func() { t.Cleanup(func() {
require.False(t, wc.couldReadBytesJustBeforeClosing, "did not consume all response body bytes before closing %s:%d", file, line) require.False(t, wc.couldReadBytesJustBeforeClosing, "did not consume all response body bytes before closing %s", caller)
require.Equalf(t, wc.closeCount, 1, "did not close resp body at %s:%d", file, line) require.Equalf(t, wc.closeCount, 1, "did not close resp body expected number of times at %s for req %#v; actual calls = %s", caller, req, wc.closeCalls)
}) })
resp.Body = wc resp.Body = wc
return false, resp, err return false, resp, err

View File

@ -3,13 +3,17 @@
package kubeclient package kubeclient
import restclient "k8s.io/client-go/rest" import (
restclient "k8s.io/client-go/rest"
"k8s.io/client-go/transport"
)
type Option func(*clientConfig) type Option func(*clientConfig)
type clientConfig struct { type clientConfig struct {
config *restclient.Config config *restclient.Config
middlewares []Middleware middlewares []Middleware
transportWrapper transport.WrapperFunc
} }
func WithConfig(config *restclient.Config) Option { func WithConfig(config *restclient.Config) Option {
@ -27,3 +31,12 @@ func WithMiddleware(middleware Middleware) Option {
c.middlewares = append(c.middlewares, middleware) c.middlewares = append(c.middlewares, middleware)
} }
} }
// WithTransportWrapper will wrap the client-go http.RoundTripper chain *after* the middleware
// wrapper is applied. I.e., this wrapper has the opportunity to supply an http.RoundTripper that
// runs first in the client-go http.RoundTripper chain.
func WithTransportWrapper(wrapper transport.WrapperFunc) Option {
return func(c *clientConfig) {
c.transportWrapper = wrapper
}
}

View File

@ -23,7 +23,7 @@ import (
"go.pinniped.dev/internal/plog" "go.pinniped.dev/internal/plog"
) )
func configWithWrapper(config *restclient.Config, scheme *runtime.Scheme, negotiatedSerializer runtime.NegotiatedSerializer, middlewares []Middleware) *restclient.Config { func configWithWrapper(config *restclient.Config, scheme *runtime.Scheme, negotiatedSerializer runtime.NegotiatedSerializer, middlewares []Middleware, wrapper transport.WrapperFunc) *restclient.Config {
hostURL, apiPathPrefix, err := getHostAndAPIPathPrefix(config) hostURL, apiPathPrefix, err := getHostAndAPIPathPrefix(config)
if err != nil { if err != nil {
plog.DebugErr("invalid rest config", err) plog.DebugErr("invalid rest config", err)
@ -49,6 +49,9 @@ func configWithWrapper(config *restclient.Config, scheme *runtime.Scheme, negoti
cc := restclient.CopyConfig(config) cc := restclient.CopyConfig(config)
cc.Wrap(f) cc.Wrap(f)
if wrapper != nil {
cc.Wrap(wrapper)
}
return cc return cc
} }
@ -173,20 +176,20 @@ func handleOtherVerbs(
resp, err := rt.RoundTrip(newReq) resp, err := rt.RoundTrip(newReq)
if err != nil { if err != nil {
return true, nil, fmt.Errorf("middleware request for %#v failed: %w", middlewareReq, err) return false, nil, fmt.Errorf("middleware request for %#v failed: %w", middlewareReq, err)
} }
switch v { switch v {
case VerbDelete, VerbDeleteCollection: case VerbDelete, VerbDeleteCollection:
return true, resp, nil // we do not need to fix the response on delete return false, resp, nil // we do not need to fix the response on delete
case VerbWatch: case VerbWatch:
resp, err := handleWatchResponseNewGVK(config, negotiatedSerializer, resp, middlewareReq, result) resp, err := handleWatchResponseNewGVK(config, negotiatedSerializer, resp, middlewareReq, result)
return true, resp, err return false, resp, err
default: // VerbGet, VerbList, VerbPatch default: // VerbGet, VerbList, VerbPatch
resp, err := handleResponseNewGVK(config, negotiatedSerializer, resp, middlewareReq, result) resp, err := handleResponseNewGVK(config, negotiatedSerializer, resp, middlewareReq, result)
return true, resp, err return false, resp, err
} }
} }

View File

@ -19,6 +19,7 @@ package fakekubeapi
import ( import (
"encoding/pem" "encoding/pem"
"fmt"
"io/ioutil" "io/ioutil"
"mime" "mime"
"net/http" "net/http"
@ -39,20 +40,6 @@ import (
"go.pinniped.dev/internal/multierror" "go.pinniped.dev/internal/multierror"
) )
// Unlike the standard httperr.New(), this one does not prepend error messages with any prefix.
type plainHTTPErr struct {
code int
msg string
}
func (e plainHTTPErr) Error() string {
return e.msg
}
func (e plainHTTPErr) Respond(w http.ResponseWriter) {
http.Error(w, e.msg, e.code)
}
// Start starts an httptest.Server (with TLS) that pretends to be a Kube API server. // Start starts an httptest.Server (with TLS) that pretends to be a Kube API server.
// //
// The server uses the provided resources map to store API Object's. The map should be from API path // The server uses the provided resources map to store API Object's. The map should be from API path
@ -62,9 +49,9 @@ func (e plainHTTPErr) Respond(w http.ResponseWriter) {
// to the server. // to the server.
// //
// Note! Only these following verbs are (partially) supported: create, get, update, delete. // Note! Only these following verbs are (partially) supported: create, get, update, delete.
func Start(t *testing.T, resources map[string]metav1.Object) (*httptest.Server, *restclient.Config) { func Start(t *testing.T, resources map[string]runtime.Object) (*httptest.Server, *restclient.Config) {
if resources == nil { if resources == nil {
resources = make(map[string]metav1.Object) resources = make(map[string]runtime.Object)
} }
server := httptest.NewTLSServer(httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (err error) { server := httptest.NewTLSServer(httperr.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (err error) {
@ -78,12 +65,8 @@ func Start(t *testing.T, resources map[string]metav1.Object) (*httptest.Server,
return err return err
} }
if r.Method != http.MethodDelete && obj == nil { if obj == nil {
return &plainHTTPErr{ obj = newNotFoundStatus(r.URL.Path)
code: http.StatusNotFound,
// This is representative of a real Kube 404 message body.
msg: `{"kind":"Status","apiVersion":"v1","metadata":{},"status":"Failure","message":"not found","reason":"NotFound","details":{"name":"not-found","kind":"pods"},"code":404}`,
}
} }
if err := encodeObj(w, r, obj); err != nil { if err := encodeObj(w, r, obj); err != nil {
@ -101,7 +84,7 @@ func Start(t *testing.T, resources map[string]metav1.Object) (*httptest.Server,
return server, restConfig return server, restConfig
} }
func decodeObj(r *http.Request) (metav1.Object, error) { func decodeObj(r *http.Request) (runtime.Object, error) {
switch r.Method { switch r.Method {
case http.MethodPut, http.MethodPost: case http.MethodPut, http.MethodPost:
default: default:
@ -123,7 +106,7 @@ func decodeObj(r *http.Request) (metav1.Object, error) {
return nil, httperr.Wrap(http.StatusInternalServerError, "read body", err) return nil, httperr.Wrap(http.StatusInternalServerError, "read body", err)
} }
var obj metav1.Object var obj runtime.Object
multiErr := multierror.New() multiErr := multierror.New()
codecsThatWeUseInOurCode := []runtime.NegotiatedSerializer{ codecsThatWeUseInOurCode := []runtime.NegotiatedSerializer{
kubescheme.Codecs, kubescheme.Codecs,
@ -145,7 +128,7 @@ func tryDecodeObj(
mediaType string, mediaType string,
body []byte, body []byte,
negotiatedSerializer runtime.NegotiatedSerializer, negotiatedSerializer runtime.NegotiatedSerializer,
) (metav1.Object, error) { ) (runtime.Object, error) {
serializerInfo, ok := runtime.SerializerInfoForMediaType(negotiatedSerializer.SupportedMediaTypes(), mediaType) serializerInfo, ok := runtime.SerializerInfoForMediaType(negotiatedSerializer.SupportedMediaTypes(), mediaType)
if !ok { if !ok {
return nil, httperr.Newf(http.StatusInternalServerError, "unable to find serialier with content-type %s", mediaType) return nil, httperr.Newf(http.StatusInternalServerError, "unable to find serialier with content-type %s", mediaType)
@ -156,19 +139,17 @@ func tryDecodeObj(
return nil, httperr.Wrap(http.StatusInternalServerError, "decode obj", err) return nil, httperr.Wrap(http.StatusInternalServerError, "decode obj", err)
} }
return obj.(metav1.Object), nil return obj, nil
} }
func handleObj(r *http.Request, obj metav1.Object, resources map[string]metav1.Object) (metav1.Object, error) { func handleObj(r *http.Request, obj runtime.Object, resources map[string]runtime.Object) (runtime.Object, error) {
switch r.Method { switch r.Method {
case http.MethodGet: case http.MethodGet:
obj = resources[r.URL.Path] obj = resources[r.URL.Path]
case http.MethodPost, http.MethodPut: case http.MethodPost, http.MethodPut:
resources[path.Join(r.URL.Path, obj.GetName())] = obj resources[path.Join(r.URL.Path, obj.(metav1.Object).GetName())] = obj
case http.MethodDelete: case http.MethodDelete:
if _, ok := resources[r.URL.Path]; !ok { obj = resources[r.URL.Path]
return nil, httperr.Newf(http.StatusNotFound, "no resource with path %q", r.URL.Path)
}
delete(resources, r.URL.Path) delete(resources, r.URL.Path)
default: default:
return nil, httperr.New(http.StatusMethodNotAllowed, "check source code for methods supported") return nil, httperr.New(http.StatusMethodNotAllowed, "check source code for methods supported")
@ -177,7 +158,18 @@ func handleObj(r *http.Request, obj metav1.Object, resources map[string]metav1.O
return obj, nil return obj, nil
} }
func encodeObj(w http.ResponseWriter, r *http.Request, obj metav1.Object) error { func newNotFoundStatus(path string) runtime.Object {
status := &metav1.Status{
Status: metav1.StatusFailure,
Message: fmt.Sprintf("couldn't find object for path %q", path),
Reason: metav1.StatusReasonNotFound,
Code: http.StatusNotFound,
}
status.APIVersion, status.Kind = metav1.SchemeGroupVersion.WithKind("Status").ToAPIVersionAndKind()
return status
}
func encodeObj(w http.ResponseWriter, r *http.Request, obj runtime.Object) error {
if r.Method == http.MethodDelete { if r.Method == http.MethodDelete {
return nil return nil
} }