Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ func TestIntegrations(t *testing.T) {
t.Run("BPFSessionDifferentiation", suite.bind(testBPFSessionDifferentiation))
t.Run("ClientIdleConnection", suite.bind(testClientIdleConnection))
t.Run("CmdLabels", suite.bind(testCmdLabels))
t.Run("CreateAndUpdateTrustedClusters", suite.bind(testCreateAndUpdateTrustedClusters))
t.Run("ControlMaster", suite.bind(testControlMaster))
t.Run("X11Forwarding", suite.bind(testX11Forwarding))
t.Run("CustomReverseTunnel", suite.bind(testCustomReverseTunnel))
t.Run("DataTransfer", suite.bind(testDataTransfer))
t.Run("DifferentPinnedIP", suite.bind(testDifferentPinnedIP))
Expand Down Expand Up @@ -199,12 +199,12 @@ func TestIntegrations(t *testing.T) {
t.Run("TrustedClustersRoleMapChanges", suite.bind(testTrustedClustersRoleMapChanges))
t.Run("TrustedClustersWithLabels", suite.bind(testTrustedClustersWithLabels))
t.Run("TrustedClustersSkipNameValidation", suite.bind(testTrustedClustersSkipNameValidation))
t.Run("CreateAndUpdateTrustedClusters", suite.bind(testCreateAndUpdateTrustedClusters))
t.Run("TrustedTunnelNode", suite.bind(testTrustedTunnelNode))
t.Run("TwoClustersProxy", suite.bind(testTwoClustersProxy))
t.Run("TwoClustersTunnel", suite.bind(testTwoClustersTunnel))
t.Run("UUIDBasedProxy", suite.bind(testUUIDBasedProxy))
t.Run("WindowChange", suite.bind(testWindowChange))
t.Run("X11Forwarding", suite.bind(testX11Forwarding))
}

// testDifferentPinnedIP tests connection is rejected when source IP doesn't match the pinned one
Expand Down
5 changes: 1 addition & 4 deletions lib/reversetunnel/localsite.go
Original file line number Diff line number Diff line change
Expand Up @@ -471,13 +471,10 @@ func (s *localSite) dialAndForward(params reversetunnelclient.DialParams) (_ net
DataDir: s.srv.Config.DataDir,
Address: params.Address,
UseTunnel: useTunnel,
HostUUID: s.srv.ID,
ProxyUUID: s.srv.ID,
Emitter: s.srv.Config.Emitter,
ParentContext: s.srv.Context,
LockWatcher: s.srv.LockWatcher,
TargetID: params.ServerID,
TargetAddr: params.To.String(),
TargetHostname: params.Address,
TargetServer: params.TargetServer,
Clock: s.clock,
}
Expand Down
5 changes: 1 addition & 4 deletions lib/reversetunnel/remotesite.go
Original file line number Diff line number Diff line change
Expand Up @@ -913,13 +913,10 @@ func (s *remoteSite) dialAndForward(params reversetunnelclient.DialParams) (_ ne
Address: params.Address,
UseTunnel: UseTunnel(s.logger, targetConn),
FIPS: s.srv.FIPS,
HostUUID: s.srv.ID,
ProxyUUID: s.srv.ID,
Emitter: s.srv.Config.Emitter,
ParentContext: s.srv.Context,
LockWatcher: s.srv.LockWatcher,
TargetID: params.ServerID,
TargetAddr: params.To.String(),
TargetHostname: params.Address,
TargetServer: params.TargetServer,
Clock: s.clock,
}
Expand Down
2 changes: 1 addition & 1 deletion lib/srv/authhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,7 @@ func (h *AuthHandlers) hostKeyCallback(hostname string, remote net.Addr, key ssh
ctx := h.c.Server.Context()

// For SubKindOpenSSHEICENode we use SSH Keys (EC2 does not support Certificates in ec2.SendSSHPublicKey).
if h.c.Server.TargetMetadata().ServerSubKind == types.SubKindOpenSSHEICENode {
if h.c.Server.GetInfo().GetSubKind() == types.SubKindOpenSSHEICENode {
return nil
}

Expand Down
9 changes: 5 additions & 4 deletions lib/srv/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ type Server interface {
GetClock() clockwork.Clock

// GetInfo returns a services.Server that represents this server.
// In the case of the Proxy forwarder, this is the node target.
GetInfo() types.Server

// UseTunnel used to determine if this node has connected to this cluster
Expand Down Expand Up @@ -190,8 +191,8 @@ type Server interface {
// support or not.
GetSELinuxEnabled() bool

// TargetMetadata returns metadata about the session target node.
TargetMetadata() apievents.ServerMetadata
// EventMetadata returns [events.ServerMetadata] for this server.
EventMetadata() apievents.ServerMetadata
}

// IdentityContext holds all identity information associated with the user
Expand Down Expand Up @@ -484,7 +485,7 @@ func NewServerContext(ctx context.Context, parent *sshutils.ConnectionContext, s
clientIdleTimeout: clientIdleTimeout,
cancelContext: cancelContext,
cancel: cancel,
ServerSubKind: srv.TargetMetadata().ServerSubKind,
ServerSubKind: srv.GetInfo().GetSubKind(),
}

child.Logger = slog.With(
Expand Down Expand Up @@ -902,7 +903,7 @@ func (c *ServerContext) reportStats(conn utils.Stater) {
Type: events.SessionDataEvent,
Code: events.SessionDataCode,
},
ServerMetadata: c.srv.TargetMetadata(),
ServerMetadata: c.srv.EventMetadata(),
SessionMetadata: c.GetSessionMetadata(),
UserMetadata: c.Identity.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
Expand Down
6 changes: 3 additions & 3 deletions lib/srv/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ func (e *localExec) transformSecureCopy() error {
Time: time.Now(),
},
UserMetadata: e.Ctx.Identity.GetUserMetadata(),
ServerMetadata: e.Ctx.GetServer().TargetMetadata(),
ServerMetadata: e.Ctx.GetServer().EventMetadata(),
Error: err.Error(),
})
return trace.Wrap(err)
Expand Down Expand Up @@ -369,7 +369,7 @@ func (e *remoteExec) Start(ctx context.Context, ch ssh.Channel) (*ExecResult, er
Time: time.Now(),
},
UserMetadata: e.ctx.Identity.GetUserMetadata(),
ServerMetadata: e.ctx.GetServer().TargetMetadata(),
ServerMetadata: e.ctx.GetServer().EventMetadata(),
Error: err.Error(),
})
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -435,7 +435,7 @@ func (e *remoteExec) PID() int {
// instead of ctx.srv.
func emitExecAuditEvent(ctx *ServerContext, cmd string, execErr error) {
// Create common fields for event.
serverMeta := ctx.GetServer().TargetMetadata()
serverMeta := ctx.GetServer().EventMetadata()
sessionMeta := ctx.GetSessionMetadata()
userMeta := ctx.Identity.GetUserMetadata()

Expand Down
2 changes: 0 additions & 2 deletions lib/srv/exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ func TestEmitExecAuditEvent(t *testing.T) {
rec, ok := scx.session.recorder.(*mockRecorder)
require.True(t, ok)

scx.GetServer().TargetMetadata()

expectedUsr, err := user.Current()
require.NoError(t, err)
expectedHostname := "testHost"
Expand Down
4 changes: 2 additions & 2 deletions lib/srv/forward/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (p *SFTPProxy) Serve() error {
Code: events.SFTPSummaryCode,
Time: time.Now(),
},
ServerMetadata: scx.GetServer().TargetMetadata(),
ServerMetadata: scx.GetServer().EventMetadata(),
SessionMetadata: scx.GetSessionMetadata(),
UserMetadata: scx.Identity.GetUserMetadata(),
ConnectionMetadata: apievents.ConnectionMetadata{
Expand Down Expand Up @@ -224,7 +224,7 @@ func (h *proxyHandlers) sendSFTPEvent(req *sftp.Request, reqErr error) {
} else if reqErr != nil {
h.logger.DebugContext(req.Context(), "failed handling SFTP request", "request", req.Method, "error", reqErr)
}
event.ServerMetadata = h.scx.GetServer().TargetMetadata()
event.ServerMetadata = h.scx.GetServer().EventMetadata()
event.SessionMetadata = h.scx.GetSessionMetadata()
event.UserMetadata = h.scx.Identity.GetUserMetadata()
event.ConnectionMetadata = apievents.ConnectionMetadata{
Expand Down
75 changes: 40 additions & 35 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"strings"
"time"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
Expand Down Expand Up @@ -83,8 +82,6 @@ import (
type Server struct {
logger *slog.Logger

id string

// targetConn is the TCP connection to the remote host.
targetConn net.Conn

Expand Down Expand Up @@ -155,9 +152,9 @@ type Server struct {

clock clockwork.Clock

// hostUUID is the UUID of the underlying proxy that the forwarding server
// proxyUUID is the UUID of the underlying proxy that the forwarding server
// is running in.
hostUUID string
proxyUUID string

// closeContext and closeCancel are used to signal to the outside
// world that this server is closed
Expand All @@ -174,9 +171,6 @@ type Server struct {
// of starting spans.
tracerProvider oteltrace.TracerProvider

// TODO(Joerger): Remove in favor of targetServer, which has more accurate values.
targetID, targetAddr, targetHostname string

// targetServer is the host that the connection is being established for.
targetServer types.Server
}
Expand Down Expand Up @@ -229,9 +223,9 @@ type ServerConfig struct {
// configuration.
FIPS bool

// HostUUID is the UUID of the underlying proxy that the forwarding server
// ProxyUUID is the UUID of the underlying proxy that the forwarding server
// is running in.
HostUUID string
ProxyUUID string

// Emitter is audit events emitter
Emitter events.StreamEmitter
Expand All @@ -247,9 +241,6 @@ type ServerConfig struct {
// of starting spans.
TracerProvider oteltrace.TracerProvider

// TODO(Joerger): Remove in favor of TargetServer, which has more accurate values.
TargetID, TargetAddr, TargetHostname string

// TargetServer is the host that the connection is being established for.
TargetServer types.Server
}
Expand Down Expand Up @@ -328,7 +319,6 @@ func New(c ServerConfig) (*Server, error) {
"src_addr", c.SrcAddr.String(),
"dst_addr", c.DstAddr.String(),
),
id: uuid.New().String(),
targetConn: c.TargetConn,
serverConn: utils.NewTrackingConn(serverConn),
clientConn: clientConn,
Expand All @@ -341,14 +331,11 @@ func New(c ServerConfig) (*Server, error) {
authService: c.LocalAuthClient,
dataDir: c.DataDir,
clock: c.Clock,
hostUUID: c.HostUUID,
proxyUUID: c.ProxyUUID,
StreamEmitter: c.Emitter,
parentContext: c.ParentContext,
lockWatcher: c.LockWatcher,
tracerProvider: c.TracerProvider,
targetID: c.TargetID,
targetAddr: c.TargetAddr,
targetHostname: c.TargetHostname,
targetServer: c.TargetServer,
}

Expand Down Expand Up @@ -394,16 +381,18 @@ func New(c ServerConfig) (*Server, error) {
return s, nil
}

// TargetMetadata returns metadata about the forwarding target.
func (s *Server) TargetMetadata() apievents.ServerMetadata {
// EventMetadata returns metadata about the forwarding target.
func (s *Server) EventMetadata() apievents.ServerMetadata {
serverInfo := s.GetInfo()
return apievents.ServerMetadata{
ServerVersion: teleport.Version,
ServerNamespace: s.GetNamespace(),
ServerID: s.targetID,
ServerAddr: s.targetAddr,
ServerHostname: s.targetHostname,
ForwardedBy: s.hostUUID,
ServerSubKind: s.targetServer.GetSubKind(),
ServerNamespace: serverInfo.GetNamespace(),
ServerID: serverInfo.GetName(),
ServerAddr: serverInfo.GetAddr(),
ServerLabels: serverInfo.GetAllLabels(),
ServerHostname: serverInfo.GetHostname(),
ServerSubKind: serverInfo.GetSubKind(),
ForwardedBy: s.proxyUUID,
}
}

Expand All @@ -418,15 +407,15 @@ func (s *Server) GetDataDir() string {
return s.dataDir
}

// ID returns the ID of the proxy that creates the in-memory forwarding server.
// ID returns the UUID of the server targeted by the forwarding server.
func (s *Server) ID() string {
return s.id
return s.targetServer.GetName()
}

// HostUUID is the UUID of the underlying proxy that the forwarding server
// is running in.
func (s *Server) HostUUID() string {
return s.hostUUID
return s.proxyUUID
}

// GetNamespace returns the namespace the forwarding server resides in.
Expand Down Expand Up @@ -499,19 +488,35 @@ func (s *Server) GetSELinuxEnabled() bool {
return false
}

// GetInfo returns a services.Server that represents this server.
// GetInfo returns a services.Server that represents the target server.
func (s *Server) GetInfo() types.Server {
return &types.ServerV2{
// Only set the address for non-tunnel nodes.
var addr string
if !s.targetServer.GetUseTunnel() {
addr = s.targetServer.GetAddr()
}

srv := &types.ServerV2{
Kind: types.KindNode,
SubKind: s.targetServer.GetSubKind(),
Version: types.V2,
Metadata: types.Metadata{
Name: s.ID(),
Namespace: s.GetNamespace(),
Name: s.targetServer.GetName(),
Namespace: s.targetServer.GetNamespace(),
Labels: s.targetServer.GetLabels(),
},
Spec: types.ServerSpecV2{
Addr: s.AdvertiseAddr(),
CmdLabels: types.LabelsToV2(s.targetServer.GetCmdLabels()),
Addr: addr,
Hostname: s.targetServer.GetHostname(),
UseTunnel: s.useTunnel,
Version: teleport.Version,
ProxyIDs: s.targetServer.GetProxyIDs(),
PublicAddrs: s.targetServer.GetPublicAddrs(),
},
}

return srv
}

// Dial returns the client connection created by pipeAddrConn.
Expand Down Expand Up @@ -1481,7 +1486,7 @@ func (s *Server) handleSubsystem(ctx context.Context, ch ssh.Channel, req *ssh.R
Time: time.Now(),
},
UserMetadata: serverContext.Identity.GetUserMetadata(),
ServerMetadata: serverContext.GetServer().TargetMetadata(),
ServerMetadata: serverContext.GetServer().EventMetadata(),
Error: err.Error(),
})
return trace.Wrap(err)
Expand Down
Loading
Loading