Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -719,8 +719,36 @@ func (f *Forwarder) writeResponseErrorToBody(rw http.ResponseWriter, respErr err
http.Error(rw, respErr.Error(), http.StatusInternalServerError)
}

// formatStatusResponseError formats the error response into a kube Status object.
// formatForwardResponseError handles errors returned from requests to the Kubernetes API.
// Any errors produced as a result of a GOAWAY request are forwarded to users as [http.StatusTooManyRequests]
// with a Retry-After header set to inform clients that they should retry the request. All
// other errors are formatted as a [metav1.Status] and written to the [http.ResponseWriter].
func (f *Forwarder) formatStatusResponseError(rw http.ResponseWriter, respErr error) {
// This detects failed requests that were terminated by the server due to GOAWAY. There
// is no direct way to detect these errors. No exported constants or error types exist from the
// standard library, see https://github.com/golang/net/blob/5ac9daca088ab4f378d7df849f6c7d28bea86071/http2/transport.go#L694.
// When a failed request is found, we return a response that indicates to clients that they
// should retry the request themselves.
if errString := respErr.Error(); strings.HasSuffix(errString, `after Request.Body was written; define Request.GetBody to avoid this error`) &&
strings.Contains(errString, `http2: Transport: cannot retry err`) {

data, err := runtime.Encode(globalKubeCodecs.LegacyCodec(), &kubeerrors.NewTooManyRequests("Connection closed by upstream Kubernetes server", 1).ErrStatus)
if err != nil {
f.log.WarnContext(f.ctx, "Failed encoding error into kube Status object", "error", err)
trace.WriteError(rw, respErr)
return
}

rw.Header().Set("Retry-After", "1")
rw.Header().Set(responsewriters.ContentTypeHeader, "application/json")
rw.WriteHeader(http.StatusTooManyRequests)

if _, err := rw.Write(data); err != nil && !utils.IsOKNetworkError(err) {
f.log.WarnContext(f.ctx, "Failed writing kube error response body", "error", err)
}
return
}

code := trace.ErrorToCode(respErr)
status := &metav1.Status{
Status: metav1.StatusFailure,
Expand All @@ -743,7 +771,7 @@ func (f *Forwarder) formatStatusResponseError(rw http.ResponseWriter, respErr er
// `Error from server (InternalError): an error on the server ("unknown")
// has prevented the request from succeeding`` instead of the correct reason.
rw.WriteHeader(trace.ErrorToCode(respErr))
if _, err := rw.Write(data); err != nil {
if _, err := rw.Write(data); err != nil && !utils.IsOKNetworkError(err) {
f.log.WarnContext(f.ctx, "Failed writing kube error response body", "error", err)
}
}
Expand Down
187 changes: 185 additions & 2 deletions lib/kube/proxy/forwarder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@
package proxy

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
Expand All @@ -40,8 +45,11 @@ import (
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"
kubeerrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/transport"
Expand Down Expand Up @@ -1023,7 +1031,7 @@ func mockAuthCtx(t *testing.T, kubeCluster string, isRemote bool) authContext {
func TestKubeFwdHTTPProxyEnv(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
f := newMockForwader(ctx, t)
f := newMockForwarder(ctx, t)

authCtx := mockAuthCtx(t, "kube-cluster", false)

Expand Down Expand Up @@ -1134,7 +1142,7 @@ func TestKubeFwdHTTPProxyEnv(t *testing.T) {
require.Equal(t, uint32(2), atomic.LoadUint32(&kubeAPICallCount))
}

func newMockForwader(ctx context.Context, t *testing.T) *Forwarder {
func newMockForwarder(ctx context.Context, t *testing.T) *Forwarder {
clock := clockwork.NewFakeClock()
cachedTransport, err := utils.NewFnCache(utils.FnCacheConfig{
TTL: transportCacheTTL,
Expand Down Expand Up @@ -1703,3 +1711,178 @@ func TestForwarderTLSConfigCAs(t *testing.T) {
})
require.True(t, getConnTLSRootsCalled)
}

func TestGOAWAYHandling(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())
t.Cleanup(cancel)
f := newMockForwarder(ctx, t)

cert, err := tls.X509KeyPair(fixtures.LocalhostCert, fixtures.LocalhostKey)
require.NoError(t, err)

ln, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)

// Launch a server that replies with a GOAWAY.
gs := goawayServer{
listener: ln,
tlsConfig: &tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: []string{http2.NextProtoTLS},
},
}
t.Cleanup(func() { require.NoError(t, gs.Close()) })

go func() { require.NoError(t, gs.Serve()) }()

// Insert a fake Kubernetes cluster that forwards requests to the GOAWAY server above.
f.clusterDetails = map[string]*kubeDetails{
"kube-cluster": {
kubeCreds: &staticKubeCreds{
targetAddr: gs.URL(),
tlsConfig: gs.tlsConfig,
transport: &http2.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
},
},
},
}

// Create a user session.
authCtx := mockAuthCtx(t, "kube-cluster", false)
sess, err := f.newClusterSession(ctx, authCtx)
require.NoError(t, err)
t.Cleanup(sess.close)

fwd, err := f.makeSessionForwarder(sess)
require.NoError(t, err)

// Forward all requests for this session to the GOAWAY server.
forwarderServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.URL, err = url.Parse(gs.URL())
require.NoError(t, err)
fwd.ServeHTTP(w, r)
}))

t.Cleanup(forwarderServer.Close)

// Issue a request that will be forwarded to the GOAWAY server and validate
// that the GOAWAY is caught and a 429 is returned to clients.
body := bytes.NewBuffer([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9})
req, err := http.NewRequest("GET", forwarderServer.URL, body)
require.NoError(t, err)
resp, err := forwarderServer.Client().Do(req)
require.NoError(t, err)

t.Cleanup(func() { assert.NoError(t, resp.Body.Close()) })
require.Equal(t, http.StatusTooManyRequests, resp.StatusCode)
require.Equal(t, "1", resp.Header.Get("Retry-After"))

var status metav1.Status
err = json.NewDecoder(resp.Body).Decode(&status)
require.NoError(t, err)
require.Equal(t, metav1.StatusReasonTooManyRequests, status.Reason)
}

// goawayServer is a fake [http2.Server] that terminates all received client
// connections in the same manner that a Kubernetes API Server would if
// it closed the connection as a result of the GOAWAY chance being exceeded.
type goawayServer struct {
listener net.Listener
tlsConfig *tls.Config
}

// URL returns the address clients should use to connect to the server.
func (g *goawayServer) URL() string {
return "https://" + g.listener.Addr().String()
}

// Serve listens and handles connections in a blocking manner. Call
// [Close] to terminate handling new connections and unblock.
func (g *goawayServer) Serve() error {
tlsLn := tls.NewListener(g.listener, g.tlsConfig)

for {
conn, err := tlsLn.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
return err
}

if err := g.handleConn(conn); err != nil {
return err
}
}
}

// Close terminates the server and unblocks any calls to [Serve].
func (g *goawayServer) Close() error {
return g.listener.Close()
}

// handleConn performs the initial HTTP/2 message exchange and then
// replies with a GOAWAY before closing the connection.
func (g *goawayServer) handleConn(conn net.Conn) error {
defer conn.Close()

// Consume and validate the client is communicating HTTP2
// before consuming any frames.
preface := make([]byte, len(http2.ClientPreface))
n, err := io.ReadFull(conn, preface)
if err != nil {
return err
}

if n != len(http2.ClientPreface) {
return errors.New("http2 client preface not fully provided")
}

if bytes.Contains(preface, []byte("HTTP/1.1")) {
return errors.New("expected HTTP2 in client preface, got HTTP 1.1")
}

// Start consuming HTTP2 frames
framer := http2.NewFramer(conn, conn)
framer.ReadMetaHeaders = hpack.NewDecoder(4096, nil)

// The first frame is the client SETTING
if _, err := framer.ReadFrame(); err != nil {
return err
}

// Respond with the server SETTINGS
if err := framer.WriteSettings(); err != nil {
return err
}

// Keep reading frames until the [http2.MetaHeadersFrame] is received and then
// issue a GOAWAY.
for {
frame, err := framer.ReadFrame()
if err != nil {
return err
}

switch f := frame.(type) {
case *http2.SettingsFrame:
if f.IsAck() {
continue
}
if err := framer.WriteSettingsAck(); err != nil {
return err
}
case *http2.MetaHeadersFrame:
if err := framer.WriteGoAway(f.StreamID-1, http2.ErrCodeNo, nil); err != nil {
return err
}

return nil
default:
continue
}
}
}
Loading