Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1813,11 +1813,22 @@ func (f *Forwarder) portForward(authCtx *authContext, w http.ResponseWriter, req
}

auditSent := map[string]bool{} // Set of `addr`. Can be multiple ports on single call. Using bool to simplify the check.
var auditSentMu sync.Mutex
onPortForward := func(addr string, success bool) {
if !sess.isLocalKubernetesCluster || auditSent[addr] {
if !sess.isLocalKubernetesCluster {
return
}
auditSent[addr] = true

auditSentMu.Lock()
isAuditSent := auditSent[addr]
if !isAuditSent {
auditSent[addr] = true
}
auditSentMu.Unlock()
if isAuditSent {
return
}

portForward := &apievents.PortForward{
Metadata: apievents.Metadata{
Type: events.PortForwardEvent,
Expand Down
133 changes: 127 additions & 6 deletions lib/kube/proxy/portforward_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@ import (
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
kubeerrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/httpstream"
Expand Down Expand Up @@ -149,7 +152,7 @@ func TestPortForwardKubeService(t *testing.T) {
podName: podName,
podNamespace: podNamespace,
restConfig: config,
podPort: 80,
podPorts: []int{80},
stopCh: stopCh,
readyCh: readyCh,
})
Expand Down Expand Up @@ -188,7 +191,6 @@ func TestPortForwardKubeService(t *testing.T) {
// Dial a connection to localPort.
ports, err := fw.GetPorts()
require.NoError(t, err)
require.Len(t, ports, 1)

conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", ports[0].Local))
require.NoError(t, err)
Expand All @@ -208,6 +210,121 @@ func TestPortForwardKubeService(t *testing.T) {
}
}

// TestPortForwardKubeServiceMultiPort verifies that a single port-forward
// request carrying many ports is serviced concurrently: every forwarded local
// port must round-trip data through the mock Kubernetes API server and receive
// the mock's payload prefix, the pod name, and the data echoed back.
func TestPortForwardKubeServiceMultiPort(t *testing.T) {
	t.Parallel()

	kubeMock, err := testingkubemock.NewKubeAPIMock()
	require.NoError(t, err)
	t.Cleanup(func() { kubeMock.Close() })

	// creates a Kubernetes service with a configured cluster pointing to mock api server
	testCtx := SetupTestContext(
		t.Context(),
		t,
		TestConfig{
			Clusters: []KubeClusterConfig{{Name: kubeCluster, APIEndpoint: kubeMock.URL}},
		},
	)
	t.Cleanup(func() { require.NoError(t, testCtx.Close()) })

	// create a user with access to kubernetes (kubernetes_user and kubernetes_groups specified)
	user, _ := testCtx.CreateUserAndRole(
		testCtx.Context,
		t,
		username,
		RoleSpec{
			Name:       roleName,
			KubeUsers:  roleKubeUsers,
			KubeGroups: roleKubeGroups,
		})

	// generate a kube client with user certs for auth
	_, config := testCtx.GenTestKubeClientTLSCert(
		t,
		user.GetName(),
		kubeCluster,
	)

	// Create 100 ports.
	const portCount = 100
	podPorts := make([]int, 0, portCount)
	for port := 80; port < 80+portCount; port++ {
		podPorts = append(podPorts, port)
	}

	readyCh := make(chan struct{})
	stopCh := make(chan struct{})

	forwarder := spdyPortForwardClientBuilder(t, portForwardRequestConfig{
		podName:      podName,
		podNamespace: podNamespace,
		restConfig:   config,
		podPorts:     podPorts,
		stopCh:       stopCh,
		readyCh:      readyCh,
	})

	forwarderCh := make(chan error, 1)
	t.Cleanup(func() {
		// Graceful shutdown.
		close(stopCh)
		forwarder.Close()

		// Drain the ForwardPorts result so its error is not silently
		// dropped and the goroutine is known to have exited.
		select {
		case err := <-forwarderCh:
			assert.NoError(t, err)
		case <-time.After(5 * time.Second):
			t.Error("timeout waiting for ForwardPorts to return")
		}
	})
	go func() { forwarderCh <- forwarder.ForwardPorts() }()

	// Wait for port forwarding to be ready.
	select {
	case <-time.After(5 * time.Second):
		t.Fatal("Timeout waiting for port forwarding")
	case <-readyCh:
	}

	// Port forwarding is ready.
	portPairs, err := forwarder.GetPorts()
	require.NoError(t, err)

	// Exercise every forwarded port concurrently. The zero-value errgroup
	// suffices: the context returned by errgroup.WithContext was unused.
	var g errgroup.Group
	for _, p := range portPairs {
		g.Go(func() error {
			conn, err := net.Dial("tcp", net.JoinHostPort("localhost", strconv.Itoa(int(p.Local))))
			if err != nil {
				return fmt.Errorf("unable to dial local port %d: %w", p.Local, err)
			}
			defer conn.Close()

			testData := []byte(fmt.Sprintf("test-data-port-%d", p.Local))
			if _, err := conn.Write(testData); err != nil {
				return fmt.Errorf("unable to write local port %d: %w", p.Local, err)
			}

			// Read the echoed response.
			buf := make([]byte, 1024)
			n, err := conn.Read(buf)
			if err != nil {
				return fmt.Errorf("unable to read from local port %d: %w", p.Local, err)
			}

			// The mock replies with payload prefix + pod name + the data we sent.
			expected := fmt.Sprintf("%s%s%s", testingkubemock.PortForwardPayload, podName, string(testData))
			if !strings.Contains(string(buf[:n]), expected) {
				// Argument order matches the labels: expect=expected, actual=received.
				return fmt.Errorf("unexpected response on local port %d: expect %q, actual %q",
					p.Local, expected, string(buf[:n]))
			}

			return nil
		})
	}

	require.NoError(t, g.Wait(), "Port forwarding checks failed")
}

func portforwardURL(namespace, podName string, host string, query string) (*url.URL, error) {
u, err := url.Parse(host)
if err != nil {
Expand All @@ -227,7 +344,11 @@ func spdyPortForwardClientBuilder(t *testing.T, req portForwardRequestConfig) po
u, err := portforwardURL(req.podNamespace, req.podName, req.restConfig.Host, "")
require.NoError(t, err)
dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, http.MethodPost, u)
fw, err := portforward.New(dialer, []string{fmt.Sprintf("%d:%d", 0, req.podPort)}, req.stopCh, req.readyCh, os.Stdout, os.Stdin)
ports := make([]string, len(req.podPorts))
for n, port := range req.podPorts {
ports[n] = fmt.Sprintf("0:%d", port)
}
fw, err := portforward.New(dialer, ports, req.stopCh, req.readyCh, os.Stdout, os.Stdin)
require.NoError(t, err)
return fw
}
Expand All @@ -249,8 +370,8 @@ type portForwardRequestConfig struct {
podName string
// podNamespace is the pod namespace.
podNamespace string
// podPort is the target port for the pod.
podPort int
// podPorts are the target ports for the pod.
podPorts []int
// stopCh is the channel used to manage the port forward lifecycle
stopCh <-chan struct{}
// readyCh communicates when the tunnel is ready to receive traffic
Expand Down Expand Up @@ -534,7 +655,7 @@ func TestPortForwardUnderlyingProtocol(t *testing.T) {
podName: podName,
podNamespace: podNamespace,
restConfig: config,
podPort: 80,
podPorts: []int{80},
stopCh: stopCh,
readyCh: readyCh,
})
Expand Down
113 changes: 97 additions & 16 deletions lib/kube/proxy/testing/kube_server/kube_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -892,44 +892,125 @@ func (s *KubeMockServer) portforward(w http.ResponseWriter, req *http.Request, p
}
upgrader := spdystream.NewResponseUpgraderWithPings(defaults.HighResPollingPeriod)
conn = upgrader.UpgradeResponse(w, req, httpStreamReceived(req.Context(), streamChan))

}

if conn == nil {
err = trace.ConnectionProblem(nil, "unable to upgrade connection")
return nil, err
}
defer conn.Close()
var (
data httpstream.Stream
errStream httpstream.Stream
)

// Create a context for managing goroutines.
ctx, cancel := context.WithCancel(req.Context())
defer cancel()

// Wait for all active port forwards to complete before returning.
var wg sync.WaitGroup
defer wg.Wait()

// Get pod name
podName := p.ByName("name")

type portStream struct {
data httpstream.Stream
error httpstream.Stream
processing bool // Prevent duplicate handlers
}

portStreams := make(map[string]*portStream)
var streamsMu sync.Mutex

for {
select {
case <-ctx.Done():
s.log.InfoContext(ctx, "Context canceled")
return nil, nil
case <-conn.CloseChan():
s.log.InfoContext(ctx, "Connection closed")
return nil, nil
case stream := <-streamChan:
port := stream.Headers().Get(portHeader)
if port == "" {
s.log.WarnContext(ctx, "Skipping a stream without a port header")
continue
}

streamsMu.Lock()
if _, ok := portStreams[port]; !ok {
portStreams[port] = &portStream{}
}

ps := portStreams[port]

switch stream.Headers().Get(StreamType) {
case StreamTypeError:
errStream = stream
ps.error = stream
case StreamTypeData:
data = stream
ps.data = stream
default:
s.log.WarnContext(ctx, "Unknown stream type", "type", stream.Headers().Get(StreamType))
}
}
if errStream != nil && data != nil {
break

// Check whether the port is ready to process.
if ps.data != nil && ps.error != nil && !ps.processing {
ps.processing = true

// Process each port.
// Use a separate goroutine with each port for concurrency testing.
wg.Add(1)
go s.handlePortForward(ctx, &wg, port, podName, ps.data, ps.error)
}

streamsMu.Unlock()
}
}
}

// handlePortForward reads and writes to a port-forward stream.
func (s *KubeMockServer) handlePortForward(ctx context.Context, wg *sync.WaitGroup, port string, podName string, dataStream, errorStream httpstream.Stream) {
defer wg.Done()
defer errorStream.Close()

// Unblock stream read when the context cancels.
stop := context.AfterFunc(ctx, func() { dataStream.Close() })
defer func() {
// Ensure that dataStream closes only once.
// httpstream.Stream.Close is not idempotent.
// stop() is true when AfterFunc hasn't run.
if stop() {
dataStream.Close()
}
}()

// Read from source.
buf := make([]byte, 1024)
n, err := data.Read(buf)
if err != nil {
errStream.Write([]byte(err.Error()))
return nil, nil
n, readErr := dataStream.Read(buf)

// Process any data received, regardless of error.
// Behavior is based on the io.Reader contract.
// Handles the case where Read returns data and io.EOF.
if n > 0 {
// Write to target.
_, writeErr := fmt.Fprint(dataStream, PortForwardPayload, podName, string(buf[:n]))
if writeErr != nil {
s.log.ErrorContext(ctx, "Unable to write response", "error", writeErr)
if _, errWriteErr := errorStream.Write([]byte(writeErr.Error())); errWriteErr != nil {
s.log.ErrorContext(ctx, "Unable to write error", "error", errWriteErr)
}
}
return
}
fmt.Fprint(data, PortForwardPayload, p.ByName("name"), string(buf[:n]))
return nil, nil

// Check for read error.
if readErr != nil && !errors.Is(readErr, io.EOF) {
s.log.ErrorContext(ctx, "Read error", "port", port, "error", readErr)
if _, writeErr := errorStream.Write([]byte(readErr.Error())); writeErr != nil {
s.log.ErrorContext(ctx, "Unable to write error", "error", writeErr)
}
return
}

s.log.InfoContext(ctx, "Port forward completed", "port", port)
}

// httpStreamReceived is the httpstream.NewStreamHandler for port
Expand Down
4 changes: 2 additions & 2 deletions lib/kube/proxy/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ func (c *TestContext) startKubeServices(t *testing.T) {
go func() {
err := c.KubeServer.Serve(c.kubeServerListener)
// ignore server closed error returned when .Close is called.
if errors.Is(err, http.ErrServerClosed) {
if errors.Is(err, http.ErrServerClosed) || errors.Is(err, net.ErrClosed) {
return
}
assert.NoError(t, err)
Expand All @@ -414,7 +414,7 @@ func (c *TestContext) startKubeServices(t *testing.T) {
go func() {
err := c.KubeProxy.Serve(c.kubeProxyListener)
// ignore server closed error returned when .Close is called.
if errors.Is(err, http.ErrServerClosed) {
if errors.Is(err, http.ErrServerClosed) || errors.Is(err, net.ErrClosed) {
return
}
assert.NoError(t, err)
Expand Down
Loading