Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3060,7 +3060,12 @@ func (h *Handler) siteNodeConnect(
keepAliveInterval = netConfig.GetKeepAliveInterval()
}

terminalConfig := TerminalHandlerConfig{
nw, err := site.NodeWatcher()
if err != nil {
return nil, trace.Wrap(err)
}

term, err := NewTerminal(ctx, TerminalHandlerConfig{
Term: req.Term,
SessionCtx: sessionCtx,
UserAuthClient: clt,
Expand All @@ -3078,9 +3083,18 @@ func (h *Handler) siteNodeConnect(
Tracker: tracker,
PresenceChecker: h.cfg.PresenceChecker,
WebsocketConn: ws,
}
HostNameResolver: func(serverID string) (string, error) {
matches := nw.GetNodes(r.Context(), func(n services.Node) bool {
return n.GetName() == serverID
})

if len(matches) != 1 {
return "", trace.NotFound("unable to resolve hostname for server %s", serverID)
}

term, err := NewTerminal(ctx, terminalConfig)
return matches[0].GetHostname(), nil
},
})
if err != nil {
h.log.WithError(err).Error("Unable to create terminal.")
return nil, trace.Wrap(err)
Expand Down
7 changes: 7 additions & 0 deletions lib/web/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,9 @@ func (h *Handler) executeCommand(
LocalAccessPoint: h.auth.accessPoint,
mfaFuncCache: mfaCacheFn,
buffer: buffer,
HostNameResolver: func(serverID string) (string, error) {
return serverID, nil
},
}

handler, err := newCommandHandler(ctx, commandHandlerConfig)
Expand Down Expand Up @@ -467,6 +470,7 @@ func newCommandHandler(ctx context.Context, cfg CommandHandlerConfig) (*commandH
router: cfg.Router,
localAccessPoint: cfg.LocalAccessPoint,
tracer: cfg.tracer,
resolver: cfg.HostNameResolver,
},
mfaAuthCache: cfg.mfaFuncCache,
buffer: cfg.buffer,
Expand Down Expand Up @@ -499,6 +503,9 @@ type CommandHandlerConfig struct {
// Anything requests that should be made on behalf of the user should
// use [UserAuthClient].
LocalAccessPoint localAccessPoint
// HostNameResolver allows the hostname to be determined from a server UUID
// so that a friendly name can be displayed in the console tab.
HostNameResolver func(serverID string) (hostname string, err error)
// tracer is used to create spans
tracer oteltrace.Tracer
// mfaFuncCache is used to cache the MFA auth method
Expand Down
12 changes: 9 additions & 3 deletions lib/web/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ func NewTerminal(ctx context.Context, cfg TerminalHandlerConfig) (*TerminalHandl
interactiveCommand: cfg.InteractiveCommand,
router: cfg.Router,
tracer: cfg.tracer,
resolver: cfg.HostNameResolver,
},
displayLogin: cfg.DisplayLogin,
term: cfg.Term,
Expand All @@ -161,6 +162,9 @@ type TerminalHandlerConfig struct {
// Requests that should be made on behalf of the user should
// use [UserAuthClient].
LocalAccessPoint localAccessPoint
// HostNameResolver allows the hostname to be determined from a server UUID
// so that a friendly name can be displayed in the console tab.
HostNameResolver func(serverID string) (hostname string, err error)
// DisplayLogin is the login name to display in the UI.
DisplayLogin string
// SessionData is the data to send to the client on the initial session creation.
Expand Down Expand Up @@ -275,14 +279,15 @@ type sshBaseHandler struct {
localAccessPoint localAccessPoint
// interactiveCommand is a command to execute.
interactiveCommand []string
// resolver looks up the hostname for the server UUID.
resolver func(serverID string) (hostname string, err error)
}

// localAccessPoint is a subset of the cache used to look up
// various cluster details.
type localAccessPoint interface {
GetUser(ctx context.Context, username string, withSecrets bool) (types.User, error)
GetRole(ctx context.Context, name string) (types.Role, error)
GetNode(ctx context.Context, namespace, name string) (types.Server, error)
}

// TerminalHandler connects together an SSH session with a web-based
Expand Down Expand Up @@ -370,11 +375,12 @@ func (t *TerminalHandler) writeSessionData(ctx context.Context) error {
// not be ok since this bypasses user RBAC, however, since at this point we have already
// established a connection to the target host via the user identity, the user MUST have
// access to the target host.
server, err := t.localAccessPoint.GetNode(ctx, apidefaults.Namespace, sessionDataTemp.ServerID)

hostname, err := t.resolver(sessionDataTemp.ServerID)
if err != nil {
return trace.Wrap(err)
}
sessionDataTemp.ServerHostname = server.GetHostname()
sessionDataTemp.ServerHostname = hostname

sessionMetadataResponse, err := json.Marshal(siteSessionGenerateResponse{Session: sessionDataTemp})
if err != nil {
Expand Down