Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 78 additions & 2 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func TestIntegrations(t *testing.T) {
t.Run("ClientIdleConnection", suite.bind(testClientIdleConnection))
t.Run("CmdLabels", suite.bind(testCmdLabels))
t.Run("ControlMaster", suite.bind(testControlMaster))
t.Run("X11Forwarding", suite.bind(testX11Forwarding))
t.Run("CreateAndUpdateTrustedClusters", suite.bind(testCreateAndUpdateTrustedClusters))
t.Run("CustomReverseTunnel", suite.bind(testCustomReverseTunnel))
t.Run("DataTransfer", suite.bind(testDataTransfer))
t.Run("DifferentPinnedIP", suite.bind(testDifferentPinnedIP))
Expand Down Expand Up @@ -181,6 +181,7 @@ func TestIntegrations(t *testing.T) {
t.Run("PAM", suite.bind(testPAM))
t.Run("PortForwarding", suite.bind(testPortForwarding))
t.Run("ProxyHostKeyCheck", suite.bind(testProxyHostKeyCheck))
t.Run("RecordingModesSessionTrackers", suite.bind(testRecordingModesSessionTrackers))
t.Run("ReverseTunnelCollapse", suite.bind(testReverseTunnelCollapse))
t.Run("RotateRollback", suite.bind(testRotateRollback))
t.Run("RotateSuccess", suite.bind(testRotateSuccess))
Expand All @@ -197,12 +198,12 @@ func TestIntegrations(t *testing.T) {
t.Run("TrustedClustersRoleMapChanges", suite.bind(testTrustedClustersRoleMapChanges))
t.Run("TrustedClustersWithLabels", suite.bind(testTrustedClustersWithLabels))
t.Run("TrustedClustersSkipNameValidation", suite.bind(testTrustedClustersSkipNameValidation))
t.Run("CreateAndUpdateTrustedClusters", suite.bind(testCreateAndUpdateTrustedClusters))
t.Run("TrustedTunnelNode", suite.bind(testTrustedTunnelNode))
t.Run("TwoClustersProxy", suite.bind(testTwoClustersProxy))
t.Run("TwoClustersTunnel", suite.bind(testTwoClustersTunnel))
t.Run("UUIDBasedProxy", suite.bind(testUUIDBasedProxy))
t.Run("WindowChange", suite.bind(testWindowChange))
t.Run("X11Forwarding", suite.bind(testX11Forwarding))
}

// testDifferentPinnedIP tests connection is rejected when source IP doesn't match the pinned one
Expand Down Expand Up @@ -1036,6 +1037,81 @@ func testSessionRecordingModes(t *testing.T, suite *integrationTestSuite) {
}
}

func testRecordingModesSessionTrackers(t *testing.T, suite *integrationTestSuite) {
ctx := context.Background()

cfg := suite.defaultServiceConfig()
cfg.Auth.Enabled = true
cfg.Proxy.DisableWebService = true
cfg.Proxy.DisableWebInterface = true
cfg.Proxy.Enabled = true
cfg.SSH.Enabled = true

teleport := suite.NewTeleportWithConfig(t, nil, nil, cfg)
defer teleport.StopAll()

// startSession starts an interactive session, users must terminate the
// session by typing "exit" in the terminal.
startSession := func(username string) (*Terminal, chan error) {
term := NewTerminal(250)
errCh := make(chan error)

go func() {
cl, err := teleport.NewClient(helpers.ClientConfig{
Login: username,
Cluster: helpers.Site,
Host: Host,
})
if err != nil {
errCh <- trace.Wrap(err)
return
}
cl.Stdout = term
cl.Stdin = term

ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
errCh <- cl.SSH(ctx, nil)
}()

return term, errCh
}

auth := teleport.Process.GetAuthServer()
for _, mode := range []string{types.RecordAtNode, types.RecordAtProxy} {
t.Run(mode, func(t *testing.T) {
rc := types.DefaultSessionRecordingConfig()
rc.SetMode(mode)

_, err := auth.UpsertSessionRecordingConfig(ctx, rc)
require.NoError(t, err)

// Start session.
term, errCh := startSession(suite.Me.Username)

// Validate that the session tracker exists and contains
// the correct target address.
var sessionID string
require.EventuallyWithT(t, func(t *assert.CollectT) {
trackers, err := auth.GetActiveSessionTrackers(ctx)
assert.NoError(t, err)
if !assert.Len(t, trackers, 1) {
return
}
assert.Equal(t, helpers.HostID, trackers[0].GetAddress())
sessionID = trackers[0].GetSessionID()
}, 30*time.Second, 100*time.Millisecond)

// Wait for the session to terminate without error.
term.Type("exit\n\r")
require.NoError(t, waitForError(errCh, 30*time.Second))

// Manually clean up the tracker for the session to prevent
// it leaking into the next test case.
require.NoError(t, auth.RemoveSessionTracker(ctx, sessionID))
})
}
}
func testLeafProxySessionRecording(t *testing.T, suite *integrationTestSuite) {
tests := []struct {
rootRecordingMode string
Expand Down
78 changes: 32 additions & 46 deletions lib/proxy/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"bytes"
"context"
"errors"
"fmt"
"math/rand/v2"
"net"
"os"
Expand Down Expand Up @@ -229,69 +228,52 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net.
}
span.AddEvent("retrieved target server")

serverID := target.GetName() + "." + clusterName
principals := []string{
host,
// Add in principal for when nodes are on leaf clusters.
// Required when nodes are in leaf clusters.
host + "." + clusterName,
}
proxyIDs := target.GetProxyIDs()

var (
isAgentlessNode bool
serverID string
serverAddr string
proxyIDs []string
sshSigner ssh.Signer
)

if target != nil {
proxyIDs = target.GetProxyIDs()
serverID = fmt.Sprintf("%v.%v", target.GetName(), clusterName)
// add ip if it exists to the principals
serverAddr := target.GetAddr()
switch {
case serverAddr != "":
h, _, err := net.SplitHostPort(serverAddr)
if err != nil {
return nil, trace.Wrap(err)
}

// add hostUUID.cluster to the principals
principals = append(principals, serverID)
principals = append(principals, h)
case serverAddr == "" && target.GetUseTunnel():
serverAddr = reversetunnelclient.LocalNode
}

// add ip if it exists to the principals
serverAddr = target.GetAddr()
// If the node is a registered openssh node don't set agentGetter
// so a SSH user agent will not be created when connecting to the remote node.
var sshSigner ssh.Signer
if target.IsOpenSSHNode() {
agentGetter = nil

switch {
case serverAddr != "":
h, _, err := net.SplitHostPort(serverAddr)
if target.GetSubKind() == types.SubKindOpenSSHNode {
// If the node is of SubKindOpenSSHNode, create the signer.
client, err := r.GetSiteClient(ctx, clusterName)
if err != nil {
return nil, trace.Wrap(err)
}

principals = append(principals, h)
case serverAddr == "" && target.GetUseTunnel():
serverAddr = reversetunnelclient.LocalNode
}
// If the node is a registered openssh node don't set agentGetter
// so a SSH user agent will not be created when connecting to the remote node.
if target.IsOpenSSHNode() {
agentGetter = nil
isAgentlessNode = true

if target.GetSubKind() == types.SubKindOpenSSHNode {
// If the node is of SubKindOpenSSHNode, create the signer.
client, err := r.GetSiteClient(ctx, clusterName)
if err != nil {
return nil, trace.Wrap(err)
}
sshSigner, err = signer(ctx, r.localAccessPoint, client)
if err != nil {
return nil, trace.Wrap(err)
}
sshSigner, err = signer(ctx, r.localAccessPoint, client)
if err != nil {
return nil, trace.Wrap(err)
}
}
} else {
return nil, trace.ConnectionProblem(errors.New("connection problem"), "direct dialing to nodes not found in inventory is not supported")
}

conn, err := site.Dial(reversetunnelclient.DialParams{
From: clientSrcAddr,
To: &utils.NetAddr{AddrNetwork: "tcp", Addr: serverAddr},
OriginalClientDstAddr: clientDstAddr,
GetUserAgent: agentGetter,
IsAgentlessNode: isAgentlessNode,
AgentlessSigner: sshSigner,
Address: host,
Principals: apiutils.Deduplicate(principals),
Expand Down Expand Up @@ -519,7 +501,11 @@ func getServerWithResolver(ctx context.Context, host, port string, site site, re
server = matches[0]
}

if routeMatcher.MatchesServerIDs() && server == nil {
if server != nil {
return server, nil
}

if routeMatcher.MatchesServerIDs() {
idType := "UUID"
if aws.IsEC2NodeID(host) {
idType = "EC2"
Expand All @@ -528,7 +514,7 @@ func getServerWithResolver(ctx context.Context, host, port string, site site, re
return nil, trace.NotFound("unable to locate node matching %s-like target %s", idType, host)
}

return server, nil
return nil, trace.ConnectionProblem(errors.New("connection problem"), "direct dialing to nodes not found in inventory is not supported")
}

// DialSite establishes a connection to the auth server in the provided
Expand Down
4 changes: 1 addition & 3 deletions lib/proxy/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ func TestGetServers(t *testing.T) {
name: "no matches for hostname",
site: testSite{cfg: &unambiguousCfg},
host: "test",
errAssertion: require.NoError,
errAssertion: require.Error,
serverAssertion: func(t *testing.T, srv types.Server) {
require.Empty(t, srv)
},
Expand Down Expand Up @@ -799,7 +799,6 @@ func TestRouter_DialHost(t *testing.T) {
require.Equal(t, agentlessSrv, params.TargetServer)
require.Nil(t, params.GetUserAgent)
require.NotNil(t, params.AgentlessSigner)
require.True(t, params.IsAgentlessNode)
require.NotNil(t, conn)
require.Contains(t, params.Principals, "host")
require.Contains(t, params.Principals, "host.test")
Expand All @@ -819,7 +818,6 @@ func TestRouter_DialHost(t *testing.T) {
require.Equal(t, agentlessEC2ICESrv, params.TargetServer)
require.Nil(t, params.GetUserAgent)
require.Nil(t, params.AgentlessSigner)
require.True(t, params.IsAgentlessNode)
require.NotNil(t, conn)
},
},
Expand Down
16 changes: 7 additions & 9 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,10 @@ func shouldDialAndForward(params reversetunnelclient.DialParams, recConfig types
}

func (s *localSite) Dial(params reversetunnelclient.DialParams) (net.Conn, error) {
if params.TargetServer == nil && params.ConnType == types.NodeTunnel {
return nil, trace.BadParameter("target server is required for teleport nodes")
}

if params.TargetServer != nil && params.TargetServer.GetKind() == types.KindGitServer {
return s.dialAndForwardGit(params)
}
Expand Down Expand Up @@ -282,7 +286,8 @@ func shouldSendSignedPROXYHeader(signer multiplexer.PROXYHeaderSigner, useTunnel
}

func (s *localSite) maybeSendSignedPROXYHeader(params reversetunnelclient.DialParams, conn net.Conn, useTunnel bool) error {
if !shouldSendSignedPROXYHeader(s.srv.proxySigner, useTunnel, params.IsAgentlessNode, params.From, params.OriginalClientDstAddr) {
isAgentless := params.ConnType == types.NodeTunnel && params.TargetServer != nil && params.TargetServer.IsOpenSSHNode()
if !shouldSendSignedPROXYHeader(s.srv.proxySigner, useTunnel, isAgentless, params.From, params.OriginalClientDstAddr) {
return nil
}

Expand Down Expand Up @@ -404,7 +409,7 @@ func (s *localSite) dialAndForwardGit(params reversetunnelclient.DialParams) (_
func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net.Conn, retErr error) {
ctx := s.srv.ctx

if params.GetUserAgent == nil && !params.IsAgentlessNode {
if params.GetUserAgent == nil && !params.TargetServer.IsOpenSSHNode() {
return nil, trace.BadParameter("agentless node require an agent getter")
}
s.logger.DebugContext(ctx, "Initiating dial and forwarding request",
Expand Down Expand Up @@ -451,7 +456,6 @@ func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net
LocalAuthClient: s.client,
TargetClusterAccessPoint: s.accessPoint,
UserAgent: userAgent,
IsAgentlessNode: params.IsAgentlessNode,
AgentlessSigner: params.AgentlessSigner,
TargetConn: targetConn,
SrcAddr: params.From,
Expand All @@ -467,16 +471,10 @@ func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net
Emitter: s.srv.Config.Emitter,
ParentContext: s.srv.Context,
LockWatcher: s.srv.LockWatcher,
TargetID: params.ServerID,
TargetAddr: params.To.String(),
TargetHostname: params.Address,
TargetServer: params.TargetServer,
Clock: s.clock,
}
// Ensure the hostname is set correctly if we have details of the target
if params.TargetServer != nil {
serverConfig.TargetHostname = params.TargetServer.GetHostname()
}
remoteServer, err := forward.New(serverConfig)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
19 changes: 9 additions & 10 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,10 @@ func (s *remoteSite) DialAuthServer(params reversetunnelclient.DialParams) (net.
// located in a remote connected site, the connection goes through the
// reverse proxy tunnel.
func (s *remoteSite) Dial(params reversetunnelclient.DialParams) (net.Conn, error) {
if params.TargetServer == nil && params.ConnType == types.NodeTunnel {
return nil, trace.BadParameter("target server is required for teleport nodes")
}

localRecCfg, err := s.localAccessPoint.GetSessionRecordingConfig(s.ctx)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -831,13 +835,15 @@ func (s *remoteSite) DialTCP(params reversetunnelclient.DialParams) (net.Conn, e
"target_addr", logutils.StringerAttr(params.To),
)

isAgentless := params.ConnType == types.NodeTunnel && params.TargetServer != nil && params.TargetServer.IsOpenSSHNode()

conn, err := s.connThroughTunnel(&sshutils.DialReq{
Address: params.To.String(),
ServerID: params.ServerID,
ConnType: params.ConnType,
ClientSrcAddr: stringOrEmpty(params.From),
ClientDstAddr: stringOrEmpty(params.OriginalClientDstAddr),
IsAgentlessNode: params.IsAgentlessNode,
IsAgentlessNode: isAgentless,
})
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -847,7 +853,7 @@ func (s *remoteSite) DialTCP(params reversetunnelclient.DialParams) (net.Conn, e
}

func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ net.Conn, retErr error) {
if params.GetUserAgent == nil && !params.IsAgentlessNode {
if params.GetUserAgent == nil && !params.TargetServer.IsOpenSSHNode() {
return nil, trace.BadParameter("user agent getter is required for teleport nodes")
}
s.logger.DebugContext(s.ctx, "Initiating dial and forward request",
Expand Down Expand Up @@ -882,7 +888,7 @@ func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ ne
ConnType: params.ConnType,
ClientSrcAddr: stringOrEmpty(params.From),
ClientDstAddr: stringOrEmpty(params.OriginalClientDstAddr),
IsAgentlessNode: params.IsAgentlessNode,
IsAgentlessNode: params.TargetServer.IsOpenSSHNode(),
})
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -895,7 +901,6 @@ func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ ne
LocalAuthClient: s.localClient,
TargetClusterAccessPoint: s.remoteAccessPoint,
UserAgent: userAgent,
IsAgentlessNode: params.IsAgentlessNode,
AgentlessSigner: params.AgentlessSigner,
TargetConn: targetConn,
SrcAddr: params.From,
Expand All @@ -912,16 +917,10 @@ func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ ne
Emitter: s.srv.Config.Emitter,
ParentContext: s.srv.Context,
LockWatcher: s.srv.LockWatcher,
TargetID: params.ServerID,
TargetAddr: params.To.String(),
TargetHostname: params.Address,
TargetServer: params.TargetServer,
Clock: s.clock,
}
// Ensure the hostname is set correctly if we have details of the target
if params.TargetServer != nil {
serverConfig.TargetHostname = params.TargetServer.GetHostname()
}
remoteServer, err := forward.New(serverConfig)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
Loading
Loading