Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions api/client/proxy/transport/transportv1/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,29 @@ func (c clusterStream) Close() error {
// the hostport. If a keyring is provided then it will be forwarded to the remote instance.
// The src address will be used as the LocalAddr of the returned [net.Conn].
func (c *Client) DialHost(ctx context.Context, hostport, cluster string, src net.Addr, keyring agent.ExtendedAgent) (net.Conn, *transportv1pb.ClusterDetails, error) {
ctx, cancel := context.WithCancel(ctx)
stream, err := c.clt.ProxySSH(ctx)
if err != nil {
cancel()
return nil, nil, trace.Wrap(err, "unable to establish proxy stream")
}

if err := stream.Send(&transportv1pb.ProxySSHRequest{DialTarget: &transportv1pb.TargetHost{
HostPort: hostport,
Cluster: cluster,
}}); err != nil {
cancel()
return nil, nil, trace.Wrap(err, "failed to send dial target request")
}

resp, err := stream.Recv()
if err != nil {
cancel()
return nil, nil, trace.Wrap(err, "failed to receive cluster details response")
}

// create streams for ssh and agent protocol
sshStream, agentStream := newSSHStreams(stream)
sshStream, agentStream := newSSHStreams(stream, cancel)

// create a reader writer for agent protocol
agentRW, err := streamutils.NewReadWriter(agentStream)
Expand Down Expand Up @@ -210,9 +214,10 @@ type sshStream struct {
closedC chan struct{}
wLock *sync.Mutex
stream transportv1pb.TransportService_ProxySSHClient
cancel context.CancelFunc
}

func newSSHStreams(stream transportv1pb.TransportService_ProxySSHClient) (ssh *sshStream, agent *sshStream) {
func newSSHStreams(stream transportv1pb.TransportService_ProxySSHClient, cancel context.CancelFunc) (ssh *sshStream, agent *sshStream) {
wLock := &sync.Mutex{}
closedC := make(chan struct{})

Expand All @@ -225,6 +230,7 @@ func newSSHStreams(stream transportv1pb.TransportService_ProxySSHClient) (ssh *s
},
wLock: wLock,
closedC: closedC,
cancel: cancel,
}

agent = &sshStream{
Expand All @@ -236,6 +242,7 @@ func newSSHStreams(stream transportv1pb.TransportService_ProxySSHClient) (ssh *s
},
wLock: wLock,
closedC: closedC,
cancel: cancel,
}

return ssh, agent
Expand Down Expand Up @@ -265,6 +272,7 @@ func (s *sshStream) Send(frame []byte) error {
}

func (s *sshStream) Close() error {
s.cancel()
// grab lock to prevent any sends from occurring
s.wLock.Lock()
defer s.wLock.Unlock()
Expand Down