Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 53 additions & 63 deletions lib/proxy/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"bytes"
"context"
"errors"
"fmt"
"log/slog"
"math/rand/v2"
"net"
Expand Down Expand Up @@ -247,74 +246,61 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net.
}
span.AddEvent("retrieved target server")

serverID := target.GetName() + "." + clusterName
hostClusterPrincipal := host + "." + clusterName
principals := []string{
host,
// Add in principal for when nodes are on leaf clusters.
host + "." + clusterName,
// Add in hostClusterPrincipal for when nodes are on leaf clusters.
hostClusterPrincipal,
}

var (
isAgentlessNode bool
serverID string
serverAddr string
proxyIDs []string
sshSigner ssh.Signer
)
// add hostUUID.cluster to the principals if it's different from hostClusterPrincipal.
if serverID != hostClusterPrincipal {
principals = append(principals, serverID)
}

if target != nil {
proxyIDs = target.GetProxyIDs()
serverID = fmt.Sprintf("%v.%v", target.GetName(), clusterName)
serverAddr := target.GetAddr()
switch {
case serverAddr != "":
h, _, err := net.SplitHostPort(serverAddr)
if err != nil {
return nil, trace.Wrap(err)
}

// add hostUUID.cluster to the principals
principals = append(principals, serverID)
principals = append(principals, h)
case serverAddr == "" && target.GetUseTunnel():
serverAddr = reversetunnelclient.LocalNode
}

// add ip if it exists to the principals
serverAddr = target.GetAddr()
// If the node is a registered openssh node don't set agentGetter
// so a SSH user agent will not be created when connecting to the remote node.
var sshSigner ssh.Signer
if target.IsOpenSSHNode() {
agentGetter = nil

switch {
case serverAddr != "":
h, _, err := net.SplitHostPort(serverAddr)
if target.GetSubKind() == types.SubKindOpenSSHNode {
// If the node is of SubKindOpenSSHNode, create the signer.
client, err := r.GetSiteClient(ctx, clusterName)
if err != nil {
return nil, trace.Wrap(err)
}

principals = append(principals, h)
case serverAddr == "" && target.GetUseTunnel():
serverAddr = reversetunnelclient.LocalNode
}
// If the node is a registered openssh node don't set agentGetter
// so a SSH user agent will not be created when connecting to the remote node.
if target.IsOpenSSHNode() {
agentGetter = nil
isAgentlessNode = true

if target.GetSubKind() == types.SubKindOpenSSHNode {
// If the node is of SubKindOpenSSHNode, create the signer.
client, err := r.GetSiteClient(ctx, clusterName)
if err != nil {
return nil, trace.Wrap(err)
}
sshSigner, err = signer(ctx, r.localAccessPoint, client)
if err != nil {
return nil, trace.Wrap(err)
}
sshSigner, err = signer(ctx, r.localAccessPoint, client)
if err != nil {
return nil, trace.Wrap(err)
}
}
} else {
return nil, trace.ConnectionProblem(errors.New("connection problem"), "direct dialing to nodes not found in inventory is not supported")
}

conn, err := cluster.Dial(reversetunnelclient.DialParams{
From: clientSrcAddr,
To: &utils.NetAddr{AddrNetwork: "tcp", Addr: serverAddr},
OriginalClientDstAddr: clientDstAddr,
GetUserAgent: agentGetter,
IsAgentlessNode: isAgentlessNode,
AgentlessSigner: sshSigner,
Address: host,
Principals: apiutils.Deduplicate(principals),
ServerID: serverID,
ProxyIDs: proxyIDs,
ProxyIDs: target.GetProxyIDs(),
ConnType: types.NodeTunnel,
TargetServer: target,
})
Expand Down Expand Up @@ -567,33 +553,37 @@ func getServerWithResolver(ctx context.Context, host, port string, cluster clust
matches = filtered
}

var server types.Server
switch {
case strategy == types.RoutingStrategy_MOST_RECENT:
for _, m := range matches {
if server == nil || m.Expiry().After(server.Expiry()) {
server = m
}
}
case len(matches) == 1:
return matches[0], nil

case len(matches) > 1:
// TODO(tross) DELETE IN V20.0.0
// NodeIsAmbiguous is included in the error message for backwards compatibility
// with older nodes that expect to see that string in the error message.
return nil, trace.Wrap(teleport.ErrNodeIsAmbiguous, teleport.NodeIsAmbiguous)
case len(matches) == 1:
server = matches[0]
}
if strategy != types.RoutingStrategy_MOST_RECENT {
return nil, trace.Wrap(teleport.ErrNodeIsAmbiguous, teleport.NodeIsAmbiguous)
}

if routeMatcher.MatchesServerIDs() && server == nil {
idType := "UUID"
if aws.IsEC2NodeID(host) {
idType = "EC2"
var recentServer types.Server
for _, m := range matches {
if recentServer == nil || m.Expiry().After(recentServer.Expiry()) {
recentServer = m
}
}
return recentServer, nil

return nil, trace.NotFound("unable to locate node matching %s-like target %s", idType, host)
}
default: // no matches
if routeMatcher.MatchesServerIDs() {
idType := "UUID"
if aws.IsEC2NodeID(host) {
idType = "EC2"
}
return nil, trace.NotFound("unable to locate node matching %s-like target %s", idType, host)
}

return server, nil
return nil, trace.ConnectionProblem(errors.New("connection problem"), "direct dialing to nodes not found in inventory is not supported")
}
}

// DialSite establishes a connection to the auth server in the provided
Expand Down
12 changes: 6 additions & 6 deletions lib/proxy/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,12 @@ func TestGetServers(t *testing.T) {
serverAssertion func(t *testing.T, srv types.Server)
}{
{
name: "no matches for hostname",
site: testSite{cfg: &unambiguousCfg},
host: "test",
errAssertion: require.NoError,
name: "no matches for hostname",
site: testSite{cfg: &unambiguousCfg},
host: "test",
errAssertion: func(t require.TestingT, err error, i ...any) {
require.True(t, trace.IsConnectionProblem(err), "Expected connection error but got %v", err)
},
serverAssertion: func(t *testing.T, srv types.Server) {
require.Empty(t, srv)
},
Expand Down Expand Up @@ -795,7 +797,6 @@ func TestRouter_DialHost(t *testing.T) {
require.Equal(t, agentlessSrv, params.TargetServer)
require.Nil(t, params.GetUserAgent)
require.NotNil(t, params.AgentlessSigner)
require.True(t, params.IsAgentlessNode)
require.NotNil(t, conn)
require.Contains(t, params.Principals, "host")
require.Contains(t, params.Principals, "host.test")
Expand All @@ -815,7 +816,6 @@ func TestRouter_DialHost(t *testing.T) {
require.Equal(t, agentlessEC2ICESrv, params.TargetServer)
require.Nil(t, params.GetUserAgent)
require.Nil(t, params.AgentlessSigner)
require.True(t, params.IsAgentlessNode)
require.NotNil(t, conn)
},
},
Expand Down
23 changes: 14 additions & 9 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,23 +237,31 @@ func (s *localSite) DialAuthServer(params reversetunnelclient.DialParams) (net.C
// shouldDialAndForward returns whether a connection should be proxied
// and forwarded or not.
func shouldDialAndForward(params reversetunnelclient.DialParams, recConfig types.SessionRecordingConfig) bool {
// only proxy and forward ssh connections
if params.ConnType != types.NodeTunnel {
return false
}
// connection is already being tunneled, do not forward
if params.FromPeerProxy {
return false
}
// the node is an agentless node, the connection must be forwarded
if params.TargetServer != nil && params.TargetServer.IsOpenSSHNode() {
if params.TargetServer.IsOpenSSHNode() {
return true
}
// proxy session recording mode is being used and an SSH session
// is being requested, the connection must be forwarded
if params.ConnType == types.NodeTunnel && services.IsRecordAtProxy(recConfig.GetMode()) {
if services.IsRecordAtProxy(recConfig.GetMode()) {
return true
}
return false
}

func (s *localSite) Dial(params reversetunnelclient.DialParams) (net.Conn, error) {
if params.TargetServer == nil && params.ConnType == types.NodeTunnel {
return nil, trace.BadParameter("target server is required for teleport nodes")
}

if params.TargetServer != nil && params.TargetServer.GetKind() == types.KindGitServer {
return s.dialAndForwardGit(params)
}
Expand Down Expand Up @@ -282,7 +290,8 @@ func shouldSendSignedPROXYHeader(signer multiplexer.PROXYHeaderSigner, useTunnel
}

func (s *localSite) maybeSendSignedPROXYHeader(params reversetunnelclient.DialParams, conn net.Conn, useTunnel bool) error {
if !shouldSendSignedPROXYHeader(s.srv.proxySigner, useTunnel, params.IsAgentlessNode, params.From, params.OriginalClientDstAddr) {
isAgentless := params.TargetServer != nil && params.TargetServer.IsOpenSSHNode()
if !shouldSendSignedPROXYHeader(s.srv.proxySigner, useTunnel, isAgentless, params.From, params.OriginalClientDstAddr) {
return nil
}

Expand Down Expand Up @@ -404,7 +413,7 @@ func (s *localSite) dialAndForwardGit(params reversetunnelclient.DialParams) (_
func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net.Conn, retErr error) {
ctx := s.srv.ctx

if params.GetUserAgent == nil && !params.IsAgentlessNode {
if params.GetUserAgent == nil && !params.TargetServer.IsOpenSSHNode() {
return nil, trace.BadParameter("agentless node require an agent getter")
}
s.logger.DebugContext(ctx, "Initiating dial and forwarding request",
Expand Down Expand Up @@ -451,7 +460,6 @@ func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net
LocalAuthClient: s.client,
TargetClusterAccessPoint: s.accessPoint,
UserAgent: userAgent,
IsAgentlessNode: params.IsAgentlessNode,
AgentlessSigner: params.AgentlessSigner,
TargetConn: targetConn,
SrcAddr: params.From,
Expand All @@ -473,10 +481,7 @@ func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net
TargetServer: params.TargetServer,
Clock: s.clock,
}
// Ensure the hostname is set correctly if we have details of the target
if params.TargetServer != nil {
serverConfig.TargetHostname = params.TargetServer.GetHostname()
}

remoteServer, err := forward.New(serverConfig)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
17 changes: 9 additions & 8 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,10 @@ func (s *remoteSite) DialAuthServer(params reversetunnelclient.DialParams) (net.
// located in a remote connected site, the connection goes through the
// reverse proxy tunnel.
func (s *remoteSite) Dial(params reversetunnelclient.DialParams) (net.Conn, error) {
if params.TargetServer == nil && params.ConnType == types.NodeTunnel {
return nil, trace.BadParameter("target server is required for teleport nodes")
}

localRecCfg, err := s.localAccessPoint.GetSessionRecordingConfig(s.ctx)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -831,13 +835,15 @@ func (s *remoteSite) DialTCP(params reversetunnelclient.DialParams) (net.Conn, e
"target_addr", logutils.StringerAttr(params.To),
)

isAgentless := params.TargetServer != nil && params.TargetServer.IsOpenSSHNode()

conn, err := s.connThroughTunnel(&sshutils.DialReq{
Address: params.To.String(),
ServerID: params.ServerID,
ConnType: params.ConnType,
ClientSrcAddr: stringOrEmpty(params.From),
ClientDstAddr: stringOrEmpty(params.OriginalClientDstAddr),
IsAgentlessNode: params.IsAgentlessNode,
IsAgentlessNode: isAgentless,
})
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -847,7 +853,7 @@ func (s *remoteSite) DialTCP(params reversetunnelclient.DialParams) (net.Conn, e
}

func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ net.Conn, retErr error) {
if params.GetUserAgent == nil && !params.IsAgentlessNode {
if params.GetUserAgent == nil && !params.TargetServer.IsOpenSSHNode() {
return nil, trace.BadParameter("user agent getter is required for teleport nodes")
}
s.logger.DebugContext(s.ctx, "Initiating dial and forward request",
Expand Down Expand Up @@ -882,7 +888,7 @@ func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ ne
ConnType: params.ConnType,
ClientSrcAddr: stringOrEmpty(params.From),
ClientDstAddr: stringOrEmpty(params.OriginalClientDstAddr),
IsAgentlessNode: params.IsAgentlessNode,
IsAgentlessNode: params.TargetServer.IsOpenSSHNode(),
})
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -895,7 +901,6 @@ func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ ne
LocalAuthClient: s.localClient,
TargetClusterAccessPoint: s.remoteAccessPoint,
UserAgent: userAgent,
IsAgentlessNode: params.IsAgentlessNode,
AgentlessSigner: params.AgentlessSigner,
TargetConn: targetConn,
SrcAddr: params.From,
Expand All @@ -918,10 +923,6 @@ func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ ne
TargetServer: params.TargetServer,
Clock: s.clock,
}
// Ensure the hostname is set correctly if we have details of the target
if params.TargetServer != nil {
serverConfig.TargetHostname = params.TargetServer.GetHostname()
}
remoteServer, err := forward.New(serverConfig)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
4 changes: 0 additions & 4 deletions lib/reversetunnelclient/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@ type DialParams struct {
// forwarding proxy.
GetUserAgent sshagent.ClientGetter

// IsAgentlessNode indicates whether the Node is an OpenSSH Node.
// This includes Nodes whose sub kind is OpenSSH and OpenSSHEICE.
IsAgentlessNode bool

// AgentlessSigner is used for authenticating to the remote host when it is an
// agentless node.
AgentlessSigner ssh.Signer
Expand Down
Loading
Loading