diff --git a/lib/proxy/router.go b/lib/proxy/router.go index ecd1c9bae3e1c..72823f5a1f4e3 100644 --- a/lib/proxy/router.go +++ b/lib/proxy/router.go @@ -22,7 +22,6 @@ import ( "bytes" "context" "errors" - "fmt" "log/slog" "math/rand/v2" "net" @@ -247,61 +246,49 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net. } span.AddEvent("retrieved target server") + serverID := target.GetName() + "." + clusterName + hostClusterPrincipal := host + "." + clusterName principals := []string{ host, - // Add in principal for when nodes are on leaf clusters. - host + "." + clusterName, + // Add in hostClusterPrincipal for when nodes are on leaf clusters. + hostClusterPrincipal, } - var ( - isAgentlessNode bool - serverID string - serverAddr string - proxyIDs []string - sshSigner ssh.Signer - ) + // add hostUUID.cluster to the principals if it's different from hostClusterPrincipal. + if serverID != hostClusterPrincipal { + principals = append(principals, serverID) + } - if target != nil { - proxyIDs = target.GetProxyIDs() - serverID = fmt.Sprintf("%v.%v", target.GetName(), clusterName) + serverAddr := target.GetAddr() + switch { + case serverAddr != "": + h, _, err := net.SplitHostPort(serverAddr) + if err != nil { + return nil, trace.Wrap(err) + } - // add hostUUID.cluster to the principals - principals = append(principals, serverID) + principals = append(principals, h) + case serverAddr == "" && target.GetUseTunnel(): + serverAddr = reversetunnelclient.LocalNode + } - // add ip if it exists to the principals - serverAddr = target.GetAddr() + // If the node is a registered openssh node don't set agentGetter + // so a SSH user agent will not be created when connecting to the remote node. + var sshSigner ssh.Signer + if target.IsOpenSSHNode() { + agentGetter = nil - switch { - case serverAddr != "": - h, _, err := net.SplitHostPort(serverAddr) + if target.GetSubKind() == types.SubKindOpenSSHNode { + // If the node is of SubKindOpenSSHNode, create the signer. + client, err := r.GetSiteClient(ctx, clusterName) if err != nil { return nil, trace.Wrap(err) } - - principals = append(principals, h) - case serverAddr == "" && target.GetUseTunnel(): - serverAddr = reversetunnelclient.LocalNode - } - // If the node is a registered openssh node don't set agentGetter - // so a SSH user agent will not be created when connecting to the remote node. - if target.IsOpenSSHNode() { - agentGetter = nil - isAgentlessNode = true - - if target.GetSubKind() == types.SubKindOpenSSHNode { - // If the node is of SubKindOpenSSHNode, create the signer. - client, err := r.GetSiteClient(ctx, clusterName) - if err != nil { - return nil, trace.Wrap(err) - } - sshSigner, err = signer(ctx, r.localAccessPoint, client) - if err != nil { - return nil, trace.Wrap(err) - } + sshSigner, err = signer(ctx, r.localAccessPoint, client) + if err != nil { + return nil, trace.Wrap(err) } } - } else { - return nil, trace.ConnectionProblem(errors.New("connection problem"), "direct dialing to nodes not found in inventory is not supported") } conn, err := cluster.Dial(reversetunnelclient.DialParams{ @@ -309,12 +296,11 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net. To: &utils.NetAddr{AddrNetwork: "tcp", Addr: serverAddr}, OriginalClientDstAddr: clientDstAddr, GetUserAgent: agentGetter, - IsAgentlessNode: isAgentlessNode, AgentlessSigner: sshSigner, Address: host, Principals: apiutils.Deduplicate(principals), ServerID: serverID, - ProxyIDs: proxyIDs, + ProxyIDs: target.GetProxyIDs(), ConnType: types.NodeTunnel, TargetServer: target, }) @@ -567,33 +553,37 @@ func getServerWithResolver(ctx context.Context, host, port string, cluster clust matches = filtered } - var server types.Server switch { - case strategy == types.RoutingStrategy_MOST_RECENT: - for _, m := range matches { - if server == nil || m.Expiry().After(server.Expiry()) { - server = m - } - } + case len(matches) == 1: + return matches[0], nil + case len(matches) > 1: // TODO(tross) DELETE IN V20.0.0 // NodeIsAmbiguous is included in the error message for backwards compatibility // with older nodes that expect to see that string in the error message. - return nil, trace.Wrap(teleport.ErrNodeIsAmbiguous, teleport.NodeIsAmbiguous) - case len(matches) == 1: - server = matches[0] - } + if strategy != types.RoutingStrategy_MOST_RECENT { + return nil, trace.Wrap(teleport.ErrNodeIsAmbiguous, teleport.NodeIsAmbiguous) + } - if routeMatcher.MatchesServerIDs() && server == nil { - idType := "UUID" - if aws.IsEC2NodeID(host) { - idType = "EC2" + var recentServer types.Server + for _, m := range matches { + if recentServer == nil || m.Expiry().After(recentServer.Expiry()) { + recentServer = m + } } + return recentServer, nil - return nil, trace.NotFound("unable to locate node matching %s-like target %s", idType, host) - } + default: // no matches + if routeMatcher.MatchesServerIDs() { + idType := "UUID" + if aws.IsEC2NodeID(host) { + idType = "EC2" + } + return nil, trace.NotFound("unable to locate node matching %s-like target %s", idType, host) + } - return server, nil + return nil, trace.ConnectionProblem(errors.New("connection problem"), "direct dialing to nodes not found in inventory is not supported") + } } // DialSite establishes a connection to the auth server in the provided diff --git a/lib/proxy/router_test.go b/lib/proxy/router_test.go index bb5addbef2e31..11fbb2fea7970 100644 --- a/lib/proxy/router_test.go +++ b/lib/proxy/router_test.go @@ -383,10 +383,12 @@ func TestGetServers(t *testing.T) { serverAssertion func(t *testing.T, srv types.Server) }{ { - name: "no matches for hostname", - site: testSite{cfg: &unambiguousCfg}, - host: "test", - errAssertion: require.NoError, + name: "no matches for hostname", + site: testSite{cfg: &unambiguousCfg}, + host: "test", + errAssertion: func(t require.TestingT, err error, i ...any) { + require.True(t, trace.IsConnectionProblem(err), "Expected connection error but got %v", err) + }, serverAssertion: func(t *testing.T, srv types.Server) { require.Empty(t, srv) }, @@ -795,7 +797,6 @@ func TestRouter_DialHost(t *testing.T) { require.Equal(t, agentlessSrv, params.TargetServer) require.Nil(t, params.GetUserAgent) require.NotNil(t, params.AgentlessSigner) - require.True(t, params.IsAgentlessNode) require.NotNil(t, conn) require.Contains(t, params.Principals, "host") require.Contains(t, params.Principals, "host.test") @@ -815,7 +816,6 @@ func TestRouter_DialHost(t *testing.T) { require.Equal(t, agentlessEC2ICESrv, params.TargetServer) require.Nil(t, params.GetUserAgent) require.Nil(t, params.AgentlessSigner) - require.True(t, params.IsAgentlessNode) require.NotNil(t, conn) }, }, diff --git a/lib/reversetunnel/localsite.go b/lib/reversetunnel/localsite.go index ebbc19bede6ae..5f4e375c27e53 100644 --- a/lib/reversetunnel/localsite.go +++ b/lib/reversetunnel/localsite.go @@ -237,23 +237,31 @@ func (s *localSite) DialAuthServer(params reversetunnelclient.DialParams) (net.C // shouldDialAndForward returns whether a connection should be proxied // and forwarded or not. func shouldDialAndForward(params reversetunnelclient.DialParams, recConfig types.SessionRecordingConfig) bool { + // only proxy and forward ssh connections + if params.ConnType != types.NodeTunnel { + return false + } // connection is already being tunneled, do not forward if params.FromPeerProxy { return false } // the node is an agentless node, the connection must be forwarded - if params.TargetServer != nil && params.TargetServer.IsOpenSSHNode() { + if params.TargetServer.IsOpenSSHNode() { return true } // proxy session recording mode is being used and an SSH session // is being requested, the connection must be forwarded - if params.ConnType == types.NodeTunnel && services.IsRecordAtProxy(recConfig.GetMode()) { + if services.IsRecordAtProxy(recConfig.GetMode()) { return true } return false } func (s *localSite) Dial(params reversetunnelclient.DialParams) (net.Conn, error) { + if params.TargetServer == nil && params.ConnType == types.NodeTunnel { + return nil, trace.BadParameter("target server is required for teleport nodes") + } + if params.TargetServer != nil && params.TargetServer.GetKind() == types.KindGitServer { return s.dialAndForwardGit(params) } @@ -282,7 +290,8 @@ func shouldSendSignedPROXYHeader(signer multiplexer.PROXYHeaderSigner, useTunnel } func (s *localSite) maybeSendSignedPROXYHeader(params reversetunnelclient.DialParams, conn net.Conn, useTunnel bool) error { - if !shouldSendSignedPROXYHeader(s.srv.proxySigner, useTunnel, params.IsAgentlessNode, params.From, params.OriginalClientDstAddr) { + isAgentless := params.TargetServer != nil && params.TargetServer.IsOpenSSHNode() + if !shouldSendSignedPROXYHeader(s.srv.proxySigner, useTunnel, isAgentless, params.From, params.OriginalClientDstAddr) { return nil } @@ -404,7 +413,7 @@ func (s *localSite) dialAndForwardGit(params reversetunnelclient.DialParams) (_ func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net.Conn, retErr error) { ctx := s.srv.ctx - if params.GetUserAgent == nil && !params.IsAgentlessNode { + if params.GetUserAgent == nil && !params.TargetServer.IsOpenSSHNode() { return nil, trace.BadParameter("agentless node require an agent getter") } s.logger.DebugContext(ctx, "Initiating dial and forwarding request", @@ -451,7 +460,6 @@ func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net LocalAuthClient: s.client, TargetClusterAccessPoint: s.accessPoint, UserAgent: userAgent, - IsAgentlessNode: params.IsAgentlessNode, AgentlessSigner: params.AgentlessSigner, TargetConn: targetConn, SrcAddr: params.From, @@ -473,10 +481,7 @@ func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net TargetServer: params.TargetServer, Clock: s.clock, } - // Ensure the hostname is set correctly if we have details of the target - if params.TargetServer != nil { - serverConfig.TargetHostname = params.TargetServer.GetHostname() - } + remoteServer, err := forward.New(serverConfig) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/reversetunnel/remotesite.go b/lib/reversetunnel/remotesite.go index 8f5af498084fe..2f6b211cb36dc 100644 --- a/lib/reversetunnel/remotesite.go +++ b/lib/reversetunnel/remotesite.go @@ -798,6 +798,10 @@ func (s *remoteSite) DialAuthServer(params reversetunnelclient.DialParams) (net. // located in a remote connected site, the connection goes through the // reverse proxy tunnel. func (s *remoteSite) Dial(params reversetunnelclient.DialParams) (net.Conn, error) { + if params.TargetServer == nil && params.ConnType == types.NodeTunnel { + return nil, trace.BadParameter("target server is required for teleport nodes") + } + localRecCfg, err := s.localAccessPoint.GetSessionRecordingConfig(s.ctx) if err != nil { return nil, trace.Wrap(err) @@ -831,13 +835,15 @@ func (s *remoteSite) DialTCP(params reversetunnelclient.DialParams) (net.Conn, e "target_addr", logutils.StringerAttr(params.To), ) + isAgentless := params.TargetServer != nil && params.TargetServer.IsOpenSSHNode() + conn, err := s.connThroughTunnel(&sshutils.DialReq{ Address: params.To.String(), ServerID: params.ServerID, ConnType: params.ConnType, ClientSrcAddr: stringOrEmpty(params.From), ClientDstAddr: stringOrEmpty(params.OriginalClientDstAddr), - IsAgentlessNode: params.IsAgentlessNode, + IsAgentlessNode: isAgentless, }) if err != nil { return nil, trace.Wrap(err) @@ -847,7 +853,7 @@ func (s *remoteSite) DialTCP(params reversetunnelclient.DialParams) (net.Conn, e } func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ net.Conn, retErr error) { - if params.GetUserAgent == nil && !params.IsAgentlessNode { + if params.GetUserAgent == nil && !params.TargetServer.IsOpenSSHNode() { return nil, trace.BadParameter("user agent getter is required for teleport nodes") } s.logger.DebugContext(s.ctx, "Initiating dial and forward request", @@ -882,7 +888,7 @@ func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ ne ConnType: params.ConnType, ClientSrcAddr: stringOrEmpty(params.From), ClientDstAddr: stringOrEmpty(params.OriginalClientDstAddr), - IsAgentlessNode: params.IsAgentlessNode, + IsAgentlessNode: params.TargetServer.IsOpenSSHNode(), }) if err != nil { return nil, trace.Wrap(err) @@ -895,7 +901,6 @@ func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ ne LocalAuthClient: s.localClient, TargetClusterAccessPoint: s.remoteAccessPoint, UserAgent: userAgent, - IsAgentlessNode: params.IsAgentlessNode, AgentlessSigner: params.AgentlessSigner, TargetConn: targetConn, SrcAddr: params.From, @@ -918,10 +923,6 @@ func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ ne TargetServer: params.TargetServer, Clock: s.clock, } - // Ensure the hostname is set correctly if we have details of the target - if params.TargetServer != nil { - serverConfig.TargetHostname = params.TargetServer.GetHostname() - } remoteServer, err := forward.New(serverConfig) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/reversetunnelclient/api.go b/lib/reversetunnelclient/api.go index 379c6ac558230..f2bf62d6cf6b3 100644 --- a/lib/reversetunnelclient/api.go +++ b/lib/reversetunnelclient/api.go @@ -47,10 +47,6 @@ type DialParams struct { // forwarding proxy. GetUserAgent sshagent.ClientGetter - // IsAgentlessNode indicates whether the Node is an OpenSSH Node. - // This includes Nodes whose sub kind is OpenSSH and OpenSSHEICE. - IsAgentlessNode bool - // AgentlessSigner is used for authenticating to the remote host when it is an // agentless node. AgentlessSigner ssh.Signer diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 1cd180e538827..8f47143694d1f 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -174,11 +174,10 @@ type Server struct { // of starting spans. tracerProvider oteltrace.TracerProvider + // TODO(Joerger): Remove in favor of targetServer, which has more accurate values. targetID, targetAddr, targetHostname string // targetServer is the host that the connection is being established for. - // It **MUST** only be populated when the target is a teleport ssh server - // or an agentless server. targetServer types.Server } @@ -248,16 +247,11 @@ type ServerConfig struct { // of starting spans. TracerProvider oteltrace.TracerProvider + // TODO(Joerger): Remove in favor of TargetServer, which has more accurate values. TargetID, TargetAddr, TargetHostname string // TargetServer is the host that the connection is being established for. - // It **MUST** only be populated when the target is a teleport ssh server - // or an agentless server. TargetServer types.Server - - // IsAgentlessNode indicates whether the targetServer is a Node with an OpenSSH server (no teleport agent). - // This includes Nodes whose sub kind is OpenSSH and OpenSSHEphemeralKey. - IsAgentlessNode bool } // CheckDefaults makes sure all required parameters are passed in. @@ -271,18 +265,17 @@ func (s *ServerConfig) CheckDefaults() error { if s.DataDir == "" { return trace.BadParameter("missing parameter DataDir") } - if s.IsAgentlessNode { - if s.TargetServer == nil { - return trace.BadParameter("target server is required for agentless nodes") - } - - if s.TargetServer.GetSubKind() == types.SubKindOpenSSHNode && s.AgentlessSigner == nil { + if s.TargetServer == nil { + return trace.BadParameter("target server is required") + } + if s.TargetServer.IsOpenSSHNode() { + if s.AgentlessSigner == nil { return trace.BadParameter("agentless signer is required for OpenSSH Nodes") } - } - - if s.UserAgent == nil && !s.IsAgentlessNode { - return trace.BadParameter("user agent required for teleport nodes (agentless)") + } else { + if s.UserAgent == nil { + return trace.BadParameter("user agent required for teleport nodes") + } } if s.TargetConn == nil { return trace.BadParameter("connection to target connection required") @@ -403,11 +396,6 @@ func New(c ServerConfig) (*Server, error) { // TargetMetadata returns metadata about the forwarding target. func (s *Server) TargetMetadata() apievents.ServerMetadata { - var subKind string - if s.targetServer != nil { - subKind = s.targetServer.GetSubKind() - } - return apievents.ServerMetadata{ ServerVersion: teleport.Version, ServerNamespace: s.GetNamespace(), @@ -415,7 +403,7 @@ func (s *Server) TargetMetadata() apievents.ServerMetadata { ServerAddr: s.targetAddr, ServerHostname: s.targetHostname, ForwardedBy: s.hostUUID, - ServerSubKind: subKind, + ServerSubKind: s.targetServer.GetSubKind(), } } @@ -609,7 +597,7 @@ func (s *Server) Serve() { return } - if s.targetServer != nil && s.targetServer.IsOpenSSHNode() { + if s.targetServer.IsOpenSSHNode() { // OpenSSH nodes don't support moderated sessions, send an error to // the user and gracefully fail if the user is attempting to create one. policySets := s.identityContext.UnstableSessionJoiningAccessChecker.SessionPolicySets() @@ -1025,7 +1013,7 @@ func (s *Server) checkTCPIPForwardRequest(ctx context.Context, r *ssh.Request) e } // RBAC checks are only necessary when connecting to an agentless node - if s.targetServer != nil && s.targetServer.IsOpenSSHNode() { + if s.targetServer.IsOpenSSHNode() { scx, err := srv.NewServerContext(s.Context(), s.connectionContext, s, s.identityContext) if err != nil { return err @@ -1108,7 +1096,7 @@ func (s *Server) handleDirectTCPIPRequest(ctx context.Context, ch ssh.Channel, r ch = scx.TrackActivity(ch) // RBAC checks are only necessary when connecting to an agentless node - if s.targetServer != nil && s.targetServer.IsOpenSSHNode() { + if s.targetServer.IsOpenSSHNode() { err = s.authHandlers.CheckPortForward(scx.DstAddr, scx, decisionpb.SSHPortForwardMode_SSH_PORT_FORWARD_MODE_LOCAL) if err != nil { s.stderrWrite(ctx, ch, err.Error()) diff --git a/lib/srv/forward/sshserver_test.go b/lib/srv/forward/sshserver_test.go index fa7e3ad67b266..d8cb2e0cec7e3 100644 --- a/lib/srv/forward/sshserver_test.go +++ b/lib/srv/forward/sshserver_test.go @@ -31,6 +31,7 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" apisshutils "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/fixtures" @@ -257,6 +258,7 @@ func TestCheckTCPIPForward(t *testing.T) { s := Server{ logger: logtest.NewLogger(), identityContext: srv.IdentityContext{Login: tt.login}, + targetServer: &types.ServerV2{}, } err := s.checkTCPIPForwardRequest(context.Background(), &ssh.Request{