Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 57 additions & 33 deletions integration/kube_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"os"
"os/user"
"strconv"
"sync"
"testing"
"time"

Expand All @@ -37,6 +38,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/http2"
"golang.org/x/sync/errgroup"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -1617,8 +1619,11 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) {

out := &bytes.Buffer{}

go func() {
err = kubeExec(proxyClientConfig, kubeExecArgs{
group := &errgroup.Group{}

// Start the main session.
group.Go(func() error {
err := kubeExec(proxyClientConfig, kubeExecArgs{
podName: pod.Name,
podNamespace: pod.Namespace,
container: pod.Spec.Containers[0].Name,
Expand All @@ -1627,9 +1632,8 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) {
tty: true,
stdin: term,
})

require.NoError(t, err)
}()
return trace.Wrap(err)
})

// We need to wait for the exec request to be handled here for the session to be
// created. Sadly though the k8s API doesn't give us much indication of when that is.
Expand All @@ -1647,19 +1651,25 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) {
return true
}, 10*time.Second, time.Second)

participantStdinR, participantStdinW := io.Pipe()
participantStdoutR, participantStdoutW := io.Pipe()
participantStdinR, participantStdinW, err := os.Pipe()
require.NoError(t, err)
participantStdoutR, participantStdoutW, err := os.Pipe()
require.NoError(t, err)
streamsMu := &sync.Mutex{}
streams := make([]*client.KubeSession, 0, 3)
observerCaptures := make([]*bytes.Buffer, 0, 2)
albProxy := helpers.MustStartMockALBProxy(t, teleport.Config.Proxy.WebAddr.Addr)

t.Run("join peer by KubeProxyAddr", func(t *testing.T) {
// join peer by KubeProxyAddr
group.Go(func() error {
tc, err := teleport.NewClient(helpers.ClientConfig{
Login: hostUsername,
Cluster: helpers.Site,
Host: Host,
})
require.NoError(t, err)
if err != nil {
return trace.Wrap(err)
}

tc.Stdin = participantStdinR
tc.Stdout = participantStdoutW
Expand All @@ -1670,27 +1680,40 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) {
KubeUsers: kubeUsers,
KubeGroups: kubeGroups,
}, tc, session, types.SessionPeerMode)
require.NoError(t, err)
if err != nil {
return trace.Wrap(err)
}
streamsMu.Lock()
streams = append(streams, stream)
streamsMu.Unlock()
stream.Wait()
// close participant stdout so that we can read it after till EOF
participantStdoutW.Close()
return nil
})

t.Run("join observer by WebProxyAddr", func(t *testing.T) {
// join observer by WebProxyAddr
group.Go(func() error {
stream, capture := kubeJoinByWebAddr(t, teleport, participantUsername, kubeUsers, kubeGroups)
streamsMu.Lock()
streams = append(streams, stream)
observerCaptures = append(observerCaptures, capture)
streamsMu.Unlock()
stream.Wait()
return nil
})
t.Run("join observer with ALPN conn upgrade", func(t *testing.T) {

// join observer with ALPN conn upgrade
group.Go(func() error {
stream, capture := kubeJoinByALBAddr(t, teleport, participantUsername, kubeUsers, kubeGroups, albProxy.Addr().String())
streamsMu.Lock()
streams = append(streams, stream)
observerCaptures = append(observerCaptures, capture)
streamsMu.Unlock()
stream.Wait()
return nil
})

require.Len(t, observerCaptures, 2)
require.Len(t, streams, 3)
for _, stream := range streams {
defer stream.Close()
}

// We wait again for the second user to finish joining the session.
// We allow a bit of time to pass here to give the session manager time to recognize the
// new IO streams of the second client.
Expand All @@ -1704,25 +1727,26 @@ func testKubeJoin(t *testing.T, suite *KubeSuite) {

// Terminate the session after a moment to allow for the IO to reach the second client.
time.AfterFunc(5*time.Second, func() {
term.Type("\aexit\n\r\a")
participantStdoutW.Close()
// send exit command to close the session
term.Type("exit 0\n\r\a")
})

t.Run("verify output", func(t *testing.T) {
// Verify peer.
participantOutput, err := io.ReadAll(participantStdoutR)
require.NoError(t, err)
require.Contains(t, string(participantOutput), "hi from term")
// wait for all clients to finish
require.NoError(t, group.Wait())

// Verify original session.
require.Contains(t, out.String(), "hi from peer")
// Verify peer.
participantOutput, err := io.ReadAll(participantStdoutR)
require.NoError(t, err)
require.Contains(t, string(participantOutput), "hi from term")

// Verify observers.
for _, capture := range observerCaptures {
require.Contains(t, capture.String(), "hi from peer")
require.Contains(t, capture.String(), "hi from term")
}
})
// Verify original session.
require.Contains(t, out.String(), "hi from peer")

// Verify observers.
for _, capture := range observerCaptures {
require.Contains(t, capture.String(), "hi from peer")
require.Contains(t, capture.String(), "hi from term")
}
}

func kubeJoinByWebAddr(t *testing.T, teleport *helpers.TeleInstance, username string, kubeUsers, kubeGroups []string) (*client.KubeSession, *bytes.Buffer) {
Expand Down
14 changes: 10 additions & 4 deletions lib/client/kubesession.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"encoding/json"
"fmt"
"io"
"sync"
"time"

"github.com/gorilla/websocket"
Expand All @@ -45,6 +46,7 @@ type KubeSession struct {
ctx context.Context
cancel context.CancelFunc
meta types.SessionTracker
wg sync.WaitGroup
}

// NewKubeSession joins a live kubernetes session.
Expand All @@ -63,7 +65,7 @@ func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionT

fmt.Printf("Joining session with participant mode: %v. \n\n", mode)

ws, resp, err := dialer.Dial(joinEndpoint, nil)
ws, resp, err := dialer.DialContext(ctx, joinEndpoint, nil)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
}
Expand Down Expand Up @@ -111,7 +113,7 @@ func NewKubeSession(ctx context.Context, tc *TeleportClient, meta types.SessionT
go handleOutgoingResizeEvents(ctx, stream, term)
go handleIncomingResizeEvents(stream, term)

s := &KubeSession{stream, term, ctx, cancel, meta}
s := &KubeSession{stream, term, ctx, cancel, meta, sync.WaitGroup{}}
err = s.handleMFA(ctx, tc, mode, stdout)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -203,7 +205,10 @@ func (s *KubeSession) handleMFA(ctx context.Context, tc *TeleportClient, mode ty

// pipeInOut starts background tasks that copy input to and from the terminal.
func (s *KubeSession) pipeInOut(stdout io.Writer, enableEscapeSequences bool, mode types.SessionParticipantMode) {
// wait for the session to copy everything
s.wg.Add(1)
go func() {
defer s.wg.Done()
defer s.cancel()
_, err := io.Copy(stdout, s.stream)
if err != nil {
Expand Down Expand Up @@ -231,7 +236,8 @@ func (s *KubeSession) pipeInOut(stdout io.Writer, enableEscapeSequences bool, mo

// Wait waits for the session to finish.
func (s *KubeSession) Wait() {
<-s.ctx.Done()
// Wait for the session to copy everything into stdout
s.wg.Wait()
}

// Close sends a close request to the other end and waits it to gracefully terminate the connection.
Expand All @@ -240,7 +246,7 @@ func (s *KubeSession) Close() error {
return trace.Wrap(err)
}

<-s.ctx.Done()
s.wg.Wait()
return trace.Wrap(s.Detach())
}

Expand Down
9 changes: 6 additions & 3 deletions lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1273,9 +1273,9 @@ func (f *Forwarder) remoteJoin(ctx *authContext, w http.ResponseWriter, req *htt
NetDialContext: sess.DialWithContext,
}

headers := req.Header
headers := http.Header{}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change meant to be in this PR? I'm not sure how it is related to the kube session data races - it doesn't appear in the data race traces, and there are no tests updated for it.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a typo I introduce in #25202 but is now covered by this PR and by #26657.

Only one will deliver the change but both need them.

if impersonationHeaders {
if headers, err = auth.IdentityForwardingHeaders(req.Context(), req.Header); err != nil {
if headers, err = auth.IdentityForwardingHeaders(req.Context(), headers); err != nil {
return nil, trace.Wrap(err)
}
}
Expand All @@ -1288,6 +1288,10 @@ func (f *Forwarder) remoteJoin(ctx *authContext, w http.ResponseWriter, req *htt

wsTarget, respTarget, err := dialer.DialContext(req.Context(), url, headers)
if err != nil {
if respTarget == nil {
return nil, trace.Wrap(err)
}
defer respTarget.Body.Close()
msg, err := io.ReadAll(respTarget.Body)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -1297,7 +1301,6 @@ func (f *Forwarder) remoteJoin(ctx *authContext, w http.ResponseWriter, req *htt
if err := json.Unmarshal(msg, &obj); err != nil {
return nil, trace.Wrap(err)
}

return obj, trace.Wrap(err)
}
defer wsTarget.Close()
Expand Down