diff --git a/integration/kube_integration_test.go b/integration/kube_integration_test.go index 715fddb287cd0..87912d5018972 100644 --- a/integration/kube_integration_test.go +++ b/integration/kube_integration_test.go @@ -29,6 +29,7 @@ import ( "os" "os/user" "strconv" + "sync" "testing" "time" @@ -37,6 +38,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/http2" + "golang.org/x/sync/errgroup" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -1617,8 +1619,11 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) { out := &bytes.Buffer{} - go func() { - err = kubeExec(proxyClientConfig, kubeExecArgs{ + group := &errgroup.Group{} + + // Start the main session. + group.Go(func() error { + err := kubeExec(proxyClientConfig, kubeExecArgs{ podName: pod.Name, podNamespace: pod.Namespace, container: pod.Spec.Containers[0].Name, @@ -1627,9 +1632,8 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) { tty: true, stdin: term, }) - - require.NoError(t, err) - }() + return trace.Wrap(err) + }) // We need to wait for the exec request to be handled here for the session to be // created. Sadly though the k8s API doesn't give us much indication of when that is. @@ -1647,19 +1651,25 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) { return true }, 10*time.Second, time.Second) - participantStdinR, participantStdinW := io.Pipe() - participantStdoutR, participantStdoutW := io.Pipe() + participantStdinR, participantStdinW, err := os.Pipe() + require.NoError(t, err) + participantStdoutR, participantStdoutW, err := os.Pipe() + require.NoError(t, err) + streamsMu := &sync.Mutex{} streams := make([]*client.KubeSession, 0, 3) observerCaptures := make([]*bytes.Buffer, 0, 2) albProxy := helpers.MustStartMockALBProxy(t, teleport.Config.Proxy.WebAddr.Addr) - t.Run("join peer by KubeProxyAddr", func(t *testing.T) { + // join peer by KubeProxyAddr + group.Go(func() error { tc, err := teleport.NewClient(helpers.ClientConfig{ Login: hostUsername, Cluster: helpers.Site, Host: Host, }) - require.NoError(t, err) + if err != nil { + return trace.Wrap(err) + } tc.Stdin = participantStdinR tc.Stdout = participantStdoutW @@ -1670,27 +1680,40 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) { KubeUsers: kubeUsers, KubeGroups: kubeGroups, }, tc, session, types.SessionPeerMode) - require.NoError(t, err) + if err != nil { + return trace.Wrap(err) + } + streamsMu.Lock() streams = append(streams, stream) + streamsMu.Unlock() + stream.Wait() + // close participant stdout so that we can read it after till EOF + participantStdoutW.Close() + return nil }) - t.Run("join observer by WebProxyAddr", func(t *testing.T) { + // join observer by WebProxyAddr + group.Go(func() error { stream, capture := kubeJoinByWebAddr(t, teleport, participantUsername, kubeUsers, kubeGroups) + streamsMu.Lock() streams = append(streams, stream) observerCaptures = append(observerCaptures, capture) + streamsMu.Unlock() + stream.Wait() + return nil }) - t.Run("join observer with ALPN conn upgrade", func(t *testing.T) { + + // join observer with ALPN conn upgrade + group.Go(func() error { stream, capture := kubeJoinByALBAddr(t, teleport, participantUsername, kubeUsers, kubeGroups, albProxy.Addr().String()) + streamsMu.Lock() streams = append(streams, stream) observerCaptures = append(observerCaptures, capture) + streamsMu.Unlock() + stream.Wait() + return nil }) - require.Len(t, observerCaptures, 2) - require.Len(t, streams, 3) - for _, stream := range streams { - defer stream.Close() - } - // We wait again for the second user to finish joining the session. // We allow a bit of time to pass here to give the session manager time to recognize the // new IO streams of the second client. @@ -1704,25 +1727,26 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) { // Terminate the session after a moment to allow for the IO to reach the second client. time.AfterFunc(5*time.Second, func() { - term.Type("\aexit\n\r\a") - participantStdoutW.Close() + // send exit command to close the session + term.Type("exit 0\n\r\a") }) - t.Run("verify output", func(t *testing.T) { - // Verify peer. - participantOutput, err := io.ReadAll(participantStdoutR) - require.NoError(t, err) - require.Contains(t, string(participantOutput), "hi from term") + // wait for all clients to finish + require.NoError(t, group.Wait()) - // Verify original session. - require.Contains(t, out.String(), "hi from peer") + // Verify peer. + participantOutput, err := io.ReadAll(participantStdoutR) + require.NoError(t, err) + require.Contains(t, string(participantOutput), "hi from term") - // Verify observers. - for _, capture := range observerCaptures { - require.Contains(t, capture.String(), "hi from peer") - require.Contains(t, capture.String(), "hi from term") - } - }) + // Verify original session. + require.Contains(t, out.String(), "hi from peer") + + // Verify observers. + for _, capture := range observerCaptures { + require.Contains(t, capture.String(), "hi from peer") + require.Contains(t, capture.String(), "hi from term") + } } func kubeJoinByWebAddr(t *testing.T, teleport *helpers.TeleInstance, username string, kubeUsers, kubeGroups []string) (*client.KubeSession, *bytes.Buffer) { diff --git a/lib/client/kubesession.go b/lib/client/kubesession.go index 8b07e99b0b1c8..f31d99add0249 100644 --- a/lib/client/kubesession.go +++ b/lib/client/kubesession.go @@ -22,6 +22,7 @@ import ( "encoding/json" "fmt" "io" + "sync" "time" "github.com/gorilla/websocket" @@ -45,6 +46,7 @@ type KubeSession struct { ctx context.Context cancel context.CancelFunc meta types.SessionTracker + wg sync.WaitGroup } // NewKubeSession joins a live kubernetes session. @@ -63,7 +65,7 @@ func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionT fmt.Printf("Joining session with participant mode: %v. \n\n", mode) - ws, resp, err := dialer.Dial(joinEndpoint, nil) + ws, resp, err := dialer.DialContext(ctx, joinEndpoint, nil) if resp != nil && resp.Body != nil { defer resp.Body.Close() } @@ -111,7 +113,7 @@ func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionT go handleOutgoingResizeEvents(ctx, stream, term) go handleIncomingResizeEvents(stream, term) - s := &KubeSession{stream, term, ctx, cancel, meta} + s := &KubeSession{stream, term, ctx, cancel, meta, sync.WaitGroup{}} err = s.handleMFA(ctx, tc, mode, stdout) if err != nil { return nil, trace.Wrap(err) @@ -203,7 +205,10 @@ func (s *KubeSession) handleMFA(ctx context.Context, tc *TeleportClient, mode ty // pipeInOut starts background tasks that copy input to and from the terminal. func (s *KubeSession) pipeInOut(stdout io.Writer, enableEscapeSequences bool, mode types.SessionParticipantMode) { + // wait for the session to copy everything + s.wg.Add(1) go func() { + defer s.wg.Done() defer s.cancel() _, err := io.Copy(stdout, s.stream) if err != nil { @@ -231,7 +236,8 @@ func (s *KubeSession) pipeInOut(stdout io.Writer, enableEscapeSequences bool, mo // Wait waits for the session to finish. func (s *KubeSession) Wait() { - <-s.ctx.Done() + // Wait for the session to copy everything into stdout + s.wg.Wait() } // Close sends a close request to the other end and waits it to gracefully terminate the connection. @@ -240,7 +246,7 @@ func (s *KubeSession) Close() error { return trace.Wrap(err) } - <-s.ctx.Done() + s.wg.Wait() return trace.Wrap(s.Detach()) } diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index fe023a239c2f8..8fd78d5ee5209 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -1273,9 +1273,9 @@ func (f *Forwarder) remoteJoin(ctx *authContext, w http.ResponseWriter, req *htt NetDialContext: sess.DialWithContext, } - headers := req.Header + headers := http.Header{} if impersonationHeaders { - if headers, err = auth.IdentityForwardingHeaders(req.Context(), req.Header); err != nil { + if headers, err = auth.IdentityForwardingHeaders(req.Context(), headers); err != nil { return nil, trace.Wrap(err) } } @@ -1288,6 +1288,10 @@ func (f *Forwarder) remoteJoin(ctx *authContext, w http.ResponseWriter, req *htt wsTarget, respTarget, err := dialer.DialContext(req.Context(), url, headers) if err != nil { + if respTarget == nil { + return nil, trace.Wrap(err) + } + defer respTarget.Body.Close() msg, err := io.ReadAll(respTarget.Body) if err != nil { return nil, trace.Wrap(err) @@ -1297,7 +1301,6 @@ func (f *Forwarder) remoteJoin(ctx *authContext, w http.ResponseWriter, req *htt if err := json.Unmarshal(msg, &obj); err != nil { return nil, trace.Wrap(err) } - return obj, trace.Wrap(err) } defer wsTarget.Close()