Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions lib/httplib/reverseproxy/reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package reverseproxy

import (
"context"
"net/http"
"net/http/httputil"
"net/url"
Expand Down Expand Up @@ -107,6 +108,32 @@ func New(opts ...Option) (*Forwarder, error) {
return fwd, nil
}

// ServeHTTP implements http.Handler for the Forwarder.
//
// Before delegating to the embedded httputil.ReverseProxy, it clears the
// http.ServerContextKey on the request context. The standard library
// reverse proxy panics with http.ErrAbortHandler when a request it is
// serving is canceled by the client while running inside an *http.Server,
// which it detects through the ServerContextKey value:
// https://cs.opensource.google/go/go/+/refs/tags/go1.24.4:src/net/http/httputil/reverseproxy.go;l=556-574;drc=e64f7ef03fdfa1c0d847c21b16c9302cc824e79b
// With the key set to nil, the proxy simply returns instead of panicking,
// allowing upstream logic to continue and release resources without having
// to recover from the panic. This is particularly important for
// long-running requests such as Kubernetes watch streams, where a
// substantial number of goroutines are spawned per stream and must be
// torn down gracefully instead of being abandoned by a panic.
func (f *Forwarder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx := context.WithValue(r.Context(), http.ServerContextKey, nil)
	f.ReverseProxy.ServeHTTP(w, r.WithContext(ctx))
}

// Option is a functional option for the forwarder.
type Option func(*Forwarder)

Expand Down
129 changes: 129 additions & 0 deletions lib/httplib/reverseproxy/reverse_proxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Teleport
* Copyright (C) 2023 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package reverseproxy

import (
"context"
"io"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"sync"
"sync/atomic"
"testing"

"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/utils"
)

// TestRequestCancelWithoutPanic tests that canceling a request does not
// cause a panic in the reverse proxy handler. This is important to ensure
// that the reverse proxy can handle client disconnects gracefully without
// crashing the server.
// It simulates a long-running request and then cancels it, ensuring that
// the frontend doesn't panic, the backend handler receives the cancellation,
// and all resources are cleaned up properly.
func TestRequestCancelWithoutPanic(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel) // Ensure the context is canceled after the test.

	var numberOfActiveRequests atomic.Int64

	wg := &sync.WaitGroup{}
	wg.Add(1)

	backend := httptest.NewServer(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			defer wg.Done()

			numberOfActiveRequests.Add(1)
			defer numberOfActiveRequests.Add(-1)

			w.WriteHeader(http.StatusOK)
			w.Write([]byte("Hello, world!"))
			// Ensure the response is flushed to the client immediately.
			w.(http.Flusher).Flush()

			// Simulate a long-running request.
			select {
			case <-r.Context().Done():
				// Request was canceled, do nothing.
				return
			case <-ctx.Done():
				// Test context was canceled. At this point, the test failed.
				panic("test context canceled before request completed")
			}
		},
		))

	t.Cleanup(backend.Close)

	backendURL, err := url.Parse(backend.URL)
	require.NoError(t, err)
	proxyHandler := newSingleHostReverseProxy(backendURL)

	wg.Add(1)
	frontend := httptest.NewServer(http.HandlerFunc(
		func(w http.ResponseWriter, r *http.Request) {
			numberOfActiveRequests.Add(1)
			proxyHandler.ServeHTTP(w, r)
			// Place the wg.Done() call here to ensure that
			// if the panic occurs, it will never be called.
			numberOfActiveRequests.Add(-1)
			wg.Done()
		}),
	)
	// Close the frontend server as well; the original code leaked it.
	t.Cleanup(frontend.Close)

	reqCtx, reqCancel := context.WithCancel(ctx)
	// Guarantee the request context is released even if the test fails
	// before reaching the explicit reqCancel() call below.
	t.Cleanup(reqCancel)
	getReq, err := http.NewRequestWithContext(reqCtx, http.MethodGet, frontend.URL, nil)
	require.NoError(t, err)

	frontendClient := frontend.Client()
	res, err := frontendClient.Do(getReq)
	require.NoError(t, err)
	t.Cleanup(func() {
		// Cancel first so the backend handler unblocks and the drain below
		// cannot hang on a still-open stream if the test failed early.
		reqCancel()
		io.Copy(io.Discard, res.Body) // Drain the body to avoid resource leaks.
		_ = res.Body.Close()          // Ensure we close the response body to avoid resource leaks.
	})

	require.Equal(t, http.StatusOK, res.StatusCode)

	data := make([]byte, 20)
	n, err := res.Body.Read(data)
	require.NoError(t, err)
	// Ensure we read the expected response.
	require.Equal(t, "Hello, world!", string(data[:n]))

	require.Equal(t, int64(2), numberOfActiveRequests.Load(), "There should be two active handlers at this point.")

	reqCancel() // Cancel the request to simulate client disconnect.
	wg.Wait()   // Wait for the backend handler to finish.

	require.Equal(t, int64(0), numberOfActiveRequests.Load(), "There should be no active handlers after the request is canceled.")
}

// newSingleHostReverseProxy wraps the standard library single-host
// reverse proxy for target in a Forwarder, so tests exercise the same
// ServeHTTP path as production code.
func newSingleHostReverseProxy(target *url.URL) *Forwarder {
	fwd := &Forwarder{
		log: utils.NewLogger(),
	}
	fwd.ReverseProxy = httputil.NewSingleHostReverseProxy(target)
	return fwd
}
21 changes: 10 additions & 11 deletions lib/kube/proxy/resource_list.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package proxy

import (
"bytes"
"context"
"io"
"net/http"
"strings"
Expand Down Expand Up @@ -197,22 +198,22 @@ func (f *Forwarder) listResourcesWatcher(req *http.Request, w http.ResponseWrite
// push events that show ephemeral containers were started if there
// are any ephemeral containers waiting to be created for this pod
// by this user
done := make(chan struct{})
var wg sync.WaitGroup
ctx, cancel := context.WithCancel(req.Context())
if podName := isRequestTargetedToPod(req, sess.apiResource); podName != "" && ok {
wg.Add(1)
go func() {
defer wg.Done()

f.sendEphemeralContainerEvents(done, req, rw, sess, podName)
f.sendEphemeralContainerEvents(ctx, rw, sess, podName)
}()
}

// Forwards the request to the target cluster.
sess.forwarder.ServeHTTP(rw, req)
// Wait for the fake event pushing goroutine to finish
close(done)
cancel()
wg.Wait()

// Once the request terminates, close the watcher and waits for resources
// cleanup.
err = rw.Close()
Expand All @@ -223,14 +224,14 @@ func (f *Forwarder) listResourcesWatcher(req *http.Request, w http.ResponseWrite
// each 5s from cache and see if they match the user and pod and namespace.
// If any match exists, it will push a fake event to the watcher stream to trick
// kubectl into creating the exec session.
func (f *Forwarder) sendEphemeralContainerEvents(done <-chan struct{}, req *http.Request, rw *responsewriters.WatcherResponseWriter, sess *clusterSession, podName string) {
func (f *Forwarder) sendEphemeralContainerEvents(ctx context.Context, rw *responsewriters.WatcherResponseWriter, sess *clusterSession, podName string) {
const backoff = 5 * time.Second
sentDebugContainers := map[string]struct{}{}
ticker := time.NewTicker(backoff)
defer ticker.Stop()
for {
wcs, err := f.getUserEphemeralContainersForPod(
req.Context(),
ctx,
sess.User.GetName(),
sess.kubeClusterName,
sess.apiResource.namespace,
Expand All @@ -245,7 +246,7 @@ func (f *Forwarder) sendEphemeralContainerEvents(done <-chan struct{}, req *http
if _, ok := sentDebugContainers[wc.Spec.ContainerName]; ok {
continue
}
evt, err := f.getPatchedPodEvent(req.Context(), sess, wc)
evt, err := f.getPatchedPodEvent(ctx, sess, wc)
if err != nil {
f.log.WithError(err).Warn("error pushing pod event")
continue
Expand All @@ -254,15 +255,13 @@ func (f *Forwarder) sendEphemeralContainerEvents(done <-chan struct{}, req *http
// push the event to the client
// this will lock until the event is pushed or the
// request context is done.
rw.PushVirtualEventToClient(req.Context(), evt)
rw.PushVirtualEventToClient(ctx, evt)
}

// wait a bit before querying the cache again, or return
// if the request has finished
select {
case <-req.Context().Done():
return
case <-done:
case <-ctx.Done():
return
case <-ticker.C:
}
Expand Down
Loading