diff --git a/api/observability/tracing/ssh/client_test.go b/api/observability/tracing/ssh/client_test.go index 1d17fa5d84d5d..030df075ec9d9 100644 --- a/api/observability/tracing/ssh/client_test.go +++ b/api/observability/tracing/ssh/client_test.go @@ -16,8 +16,10 @@ package ssh import ( "context" + "encoding/json" "fmt" "testing" + "time" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -262,3 +264,155 @@ func TestNewSession(t *testing.T) { }) } } + +// envReqParams are parameters for env request +type envReqParams struct { + Name string + Value string +} + +// TestSetEnvs verifies that client uses EnvsRequest to +// send multiple envs and falls back to sending individual "env" +// requests if the server does not support EnvsRequests. +func TestSetEnvs(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + errChan := make(chan error, 5) + + expected := map[string]string{"a": "1", "b": "2", "c": "3"} + + // used to collect individual envs requests + envReqC := make(chan envReqParams, 3) + + srv := newServer(t, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { + for { + select { + case <-ctx.Done(): + return + case ch := <-channels: + switch { + case ch == nil: + return + case ch.ChannelType() == "session": + ch, reqs, err := ch.Accept() + if err != nil { + errChan <- trace.Wrap(err, "failed to accept session channel") + return + } + + go func() { + defer ch.Close() + for i := 0; ; i++ { + select { + case <-ctx.Done(): + return + case req := <-reqs: + if req == nil { + return + } + + switch { + case i == 0 && req.Type == EnvsRequest: // accept 1st EnvsRequest + var envReq EnvsReq + if err := ssh.Unmarshal(req.Payload, &envReq); err != nil { + _ = req.Reply(false, []byte(err.Error())) + return + } + + var envs map[string]string + if err := json.Unmarshal(envReq.EnvsJSON, &envs); err != nil { + _ = req.Reply(false, []byte(err.Error())) + return + } + + for k, v := range expected { + actual, ok := envs[k] + if !ok { + _ = req.Reply(false, []byte(fmt.Sprintf("expected env %s not present", k))) + return + } + + if actual != v { + _ = req.Reply(false, []byte(fmt.Sprintf("expected value %s for env %s, got %s", v, k, actual))) + return + } + } + + _ = req.Reply(true, nil) + case i == 1 && req.Type == EnvsRequest: // reject additional EnvsRequest so we test fallbacks + _ = req.Reply(false, nil) + case i >= 2 && i <= len(expected)+2 && req.Type == "env": // accept individual "env" fallbacks. + var e envReqParams + if err := ssh.Unmarshal(req.Payload, &e); err != nil { + _ = req.Reply(false, []byte(err.Error())) + return + } + envReqC <- e + _ = req.Reply(true, nil) + default: // out of order or unexpected message + _ = req.Reply(false, []byte(fmt.Sprintf("unexpected ssh request %s on iteration %d", req.Type, i))) + errChan <- err + return + } + } + } + }() + default: + if err := ch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unexpected channel %s", ch.ChannelType())); err != nil { + errChan <- err + return + } + } + } + } + }) + + go srv.Run(errChan) + + // create a client and open a session + conn, chans, reqs := srv.GetClient(t) + client := NewClient(conn, chans, reqs) + session, err := client.NewSession(ctx) + require.NoError(t, err) + + // the first request shouldn't fall back + t.Run("envs set via envs@goteleport.com", func(t *testing.T) { + require.NoError(t, session.SetEnvs(ctx, expected)) + + select { + case <-envReqC: + t.Fatal("env request received instead of an envs@goteleport.com request") + default: + } + }) + + // subsequent requests should fall back to standard "env" requests + t.Run("envs set individually", func(t *testing.T) { + require.NoError(t, session.SetEnvs(ctx, expected)) + + envs := map[string]string{} + envsTimeout := time.NewTimer(3 * time.Second) + defer envsTimeout.Stop() + for i := 0; i < len(expected); i++ { + select { + case env := <-envReqC: + envs[env.Name] = env.Value + case <-envsTimeout.C: + t.Fatalf("Time out waiting for env request %d to be processed", i) + } + } + + for k, v := range expected { + actual, ok := envs[k] + require.True(t, ok, "expected env %s to be set", k) + require.Equal(t, v, actual, "expected value %s for env %s, got %s", v, k, actual) + } + }) + + select { + case err := <-errChan: + require.NoError(t, err) + default: + } +} diff --git a/api/observability/tracing/ssh/session.go b/api/observability/tracing/ssh/session.go index 5b53b503fe6af..bacb0b7855cb3 100644 --- a/api/observability/tracing/ssh/session.go +++ b/api/observability/tracing/ssh/session.go @@ -16,8 +16,11 @@ package ssh import ( "context" + "encoding/json" "fmt" + "strings" + "github.com/gravitational/trace" "go.opentelemetry.io/otel/attribute" semconv "go.opentelemetry.io/otel/semconv/v1.10.0" oteltrace "go.opentelemetry.io/otel/trace" @@ -51,7 +54,8 @@ func (s *Session) SendRequest(ctx context.Context, name string, wantReply bool, // no need to wrap payload here, the session's channel wrapper will do it for us s.wrapper.addContext(ctx, name) - return s.Session.SendRequest(name, wantReply, payload) + ok, err := s.Session.SendRequest(name, wantReply, payload) + return ok, trace.Wrap(err) } // Setenv sets an environment variable that will be applied to any @@ -72,7 +76,66 @@ func (s *Session) Setenv(ctx context.Context, name, value string) error { defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.Setenv(name, value) + return trace.Wrap(s.Session.Setenv(name, value)) +} + +// SetEnvs sets environment variables that will be applied to any +// command executed by Shell or Run. If the server does not handle +// [EnvsRequest] requests then the client falls back to sending individual +// "env" requests until all provided environment variables have been set +// or an error was received. +func (s *Session) SetEnvs(ctx context.Context, envs map[string]string) error { + config := tracing.NewConfig(s.wrapper.opts) + ctx, span := config.TracerProvider.Tracer(instrumentationName).Start( + ctx, + "ssh.SetEnvs", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Session"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + defer span.End() + + if len(envs) == 0 { + return nil + } + + // If the server isn't Teleport fallback to individual "env" requests + if !strings.HasPrefix(string(s.wrapper.ServerVersion()), "SSH-2.0-Teleport") { + return trace.Wrap(s.setEnvFallback(ctx, envs)) + } + + raw, err := json.Marshal(envs) + if err != nil { + return trace.Wrap(err) + } + + s.wrapper.addContext(ctx, EnvsRequest) + ok, err := s.Session.SendRequest(EnvsRequest, true, ssh.Marshal(EnvsReq{EnvsJSON: raw})) + if err != nil { + return trace.Wrap(err) + } + + // The server does not handle EnvsRequest requests so fall back + // to sending individual requests. + if !ok { + return trace.Wrap(s.setEnvFallback(ctx, envs)) + } + + return nil +} + +// setEnvFallback sends an "env" request for each item in envs. +func (s *Session) setEnvFallback(ctx context.Context, envs map[string]string) error { + for k, v := range envs { + if err := s.Setenv(ctx, k, v); err != nil { + return trace.Wrap(err, "failed to set environment variable %s", k) + } + } + + return nil } // RequestPty requests the association of a pty with the session on the remote host. @@ -95,7 +158,7 @@ func (s *Session) RequestPty(ctx context.Context, term string, h, w int, termmod defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.RequestPty(term, h, w, termmodes) + return trace.Wrap(s.Session.RequestPty(term, h, w, termmodes)) } // RequestSubsystem requests the association of a subsystem with the session on the remote host. @@ -116,7 +179,7 @@ func (s *Session) RequestSubsystem(ctx context.Context, subsystem string) error defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.RequestSubsystem(subsystem) + return trace.Wrap(s.Session.RequestSubsystem(subsystem)) } // WindowChange informs the remote host about a terminal window dimension change to h rows and w columns. @@ -138,7 +201,7 @@ func (s *Session) WindowChange(ctx context.Context, h, w int) error { defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.WindowChange(h, w) + return trace.Wrap(s.Session.WindowChange(h, w)) } // Signal sends the given signal to the remote process. @@ -159,7 +222,7 @@ func (s *Session) Signal(ctx context.Context, sig ssh.Signal) error { defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.Signal(sig) + return trace.Wrap(s.Session.Signal(sig)) } // Start runs cmd on the remote host. Typically, the remote @@ -181,7 +244,7 @@ func (s *Session) Start(ctx context.Context, cmd string) error { defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.Start(cmd) + return trace.Wrap(s.Session.Start(cmd)) } // Shell starts a login shell on the remote host. A Session only @@ -202,7 +265,7 @@ func (s *Session) Shell(ctx context.Context) error { defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.Shell() + return trace.Wrap(s.Session.Shell()) } // Run runs cmd on the remote host. Typically, the remote @@ -234,7 +297,7 @@ func (s *Session) Run(ctx context.Context, cmd string) error { defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.Run(cmd) + return trace.Wrap(s.Session.Run(cmd)) } // Output runs cmd on the remote host and returns its standard output. @@ -254,7 +317,8 @@ func (s *Session) Output(ctx context.Context, cmd string) ([]byte, error) { defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.Output(cmd) + output, err := s.Session.Output(cmd) + return output, trace.Wrap(err) } // CombinedOutput runs cmd on the remote host and returns its combined @@ -275,5 +339,6 @@ func (s *Session) CombinedOutput(ctx context.Context, cmd string) ([]byte, error defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.CombinedOutput(cmd) + output, err := s.Session.CombinedOutput(cmd) + return output, trace.Wrap(err) } diff --git a/api/observability/tracing/ssh/ssh.go b/api/observability/tracing/ssh/ssh.go index cab328923e948..efe206a42f3ab 100644 --- a/api/observability/tracing/ssh/ssh.go +++ b/api/observability/tracing/ssh/ssh.go @@ -34,6 +34,11 @@ import ( ) const ( + // EnvsRequest sets multiple environment variables that will be applied to any + // command executed by Shell or Run. + // See [EnvsReq] for the corresponding payload. + EnvsRequest = "envs@goteleport.com" + // TracingRequest is sent by clients to server to pass along tracing context. TracingRequest = "tracing@goteleport.com" @@ -44,6 +49,14 @@ const ( instrumentationName = "otelssh" ) +// EnvsReq contains json marshaled key:value pairs sent as the +// payload for an [EnvsRequest]. +type EnvsReq struct { + // EnvsJSON is a json marshaled map[string]string containing + // environment variables. + EnvsJSON []byte `json:"envs"` +} + // ContextFromRequest extracts any tracing data provided via an Envelope // in the ssh.Request payload. If the payload contains an Envelope, then // the context returned will have tracing data populated from the remote diff --git a/lib/client/session.go b/lib/client/session.go index 586cd940f856a..6852b49cf64c6 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -226,22 +226,22 @@ func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Sessi return nil, trace.Wrap(err) } + envs := map[string]string{} + // pass language info into the remote session. - evarsToPass := []string{"LANG", "LANGUAGE"} - for _, evar := range evarsToPass { - if value := os.Getenv(evar); value != "" { - err = sess.Setenv(ctx, evar, value) - if err != nil { - log.Warn(err) - } + langVars := []string{"LANG", "LANGUAGE"} + for _, env := range langVars { + if value := os.Getenv(env); value != "" { + envs[env] = value } } // pass environment variables set by client for key, val := range ns.env { - err = sess.Setenv(ctx, key, val) - if err != nil { - log.Warn(err) - } + envs[key] = val + } + + if err := sess.SetEnvs(ctx, envs); err != nil { + log.Warn(err) } // if agent forwarding was requested (and we have a agent to forward), diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 2ed6cae6cc16c..8603ee2fb3a46 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -18,6 +18,7 @@ package forward import ( "context" + "encoding/json" "fmt" "io" "net" @@ -1054,7 +1055,7 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, return s.termHandlers.HandleWinChange(ctx, ch, req, scx) case teleport.ForceTerminateRequest: return s.termHandlers.HandleForceTerminate(ch, req, scx) - case sshutils.EnvRequest: + case sshutils.EnvRequest, tracessh.EnvsRequest: // We ignore all SSH setenv requests for join-only principals. // SSH will send them anyway but it seems fine to silently drop them. case sshutils.SubsystemRequest: @@ -1090,6 +1091,8 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, return s.termHandlers.HandleForceTerminate(ch, req, scx) case sshutils.EnvRequest: return s.handleEnv(ctx, ch, req, scx) + case tracessh.EnvsRequest: + return s.handleEnvs(ctx, ch, req, scx) case sshutils.SubsystemRequest: return s.handleSubsystem(ctx, ch, req, scx) case sshutils.X11ForwardRequest: @@ -1282,6 +1285,27 @@ func (s *Server) handleEnv(ctx context.Context, ch ssh.Channel, req *ssh.Request return nil } +// handleEnvs accepts environment variables sent by the client and forwards them +// to the remote session. +func (s *Server) handleEnvs(ctx context.Context, ch ssh.Channel, req *ssh.Request, scx *srv.ServerContext) error { + var raw tracessh.EnvsReq + if err := ssh.Unmarshal(req.Payload, &raw); err != nil { + scx.Error(err) + return trace.Wrap(err, "failed to parse envs request") + } + + var envs map[string]string + if err := json.Unmarshal(raw.EnvsJSON, &envs); err != nil { + return trace.Wrap(err, "failed to unmarshal envs") + } + + if err := scx.RemoteSession.SetEnvs(ctx, envs); err != nil { + s.log.WithError(err).Debug("Unable to set environment variables") + } + + return nil +} + func (s *Server) replyError(ch ssh.Channel, req *ssh.Request, err error) { s.log.WithError(err).Errorf("failure handling SSH %q request", req.Type) // Terminate the error with a newline when writing to remote channel's diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 42102f45dc73d..78633d4932e62 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -20,6 +20,7 @@ package regular import ( "context" + "encoding/json" "fmt" "io" "net" @@ -1567,8 +1568,9 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, case sshutils.SubsystemRequest: return s.handleSubsystem(ctx, ch, req, serverContext) case sshutils.EnvRequest: - // we currently ignore setting any environment variables via SSH for security purposes return s.handleEnv(ch, req, serverContext) + case tracessh.EnvsRequest: + return s.handleEnvs(ch, req, serverContext) case sshutils.AgentForwardRequest: // process agent forwarding, but we will only forward agent to proxy in // recording proxy mode. @@ -1604,7 +1606,7 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, return s.termHandlers.HandleWinChange(ctx, ch, req, serverContext) case teleport.ForceTerminateRequest: return s.termHandlers.HandleForceTerminate(ch, req, serverContext) - case sshutils.EnvRequest: + case sshutils.EnvRequest, tracessh.EnvsRequest: // We ignore all SSH setenv requests for join-only principals. // SSH will send them anyway but it seems fine to silently drop them. case sshutils.SubsystemRequest: @@ -1652,6 +1654,8 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, return s.termHandlers.HandleForceTerminate(ch, req, serverContext) case sshutils.EnvRequest: return s.handleEnv(ch, req, serverContext) + case tracessh.EnvsRequest: + return s.handleEnvs(ch, req, serverContext) case sshutils.SubsystemRequest: // subsystems are SSH subsystems defined in http://tools.ietf.org/html/rfc4254 6.6 // they are in essence SSH session extensions, allowing to implement new SSH commands @@ -1818,7 +1822,7 @@ func (s *Server) handleSubsystem(ctx context.Context, ch ssh.Channel, req *ssh.R return nil } -// handleEnv accepts environment variables sent by the client and stores them +// handleEnv accepts an environment variable sent by the client and stores it // in connection context func (s *Server) handleEnv(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerContext) error { var e sshutils.EnvReqParams @@ -1830,6 +1834,27 @@ func (s *Server) handleEnv(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerCont return nil } +// handleEnvs accepts environment variables sent by the client and stores them +// in connection context +func (s *Server) handleEnvs(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerContext) error { + var raw tracessh.EnvsReq + if err := ssh.Unmarshal(req.Payload, &raw); err != nil { + ctx.Error(err) + return trace.Wrap(err, "failed to parse envs request") + } + + var envs map[string]string + if err := json.Unmarshal(raw.EnvsJSON, &envs); err != nil { + return trace.Wrap(err, "failed to unmarshal envs") + } + + for k, v := range envs { + ctx.SetEnv(k, v) + } + + return nil +} + // handleKeepAlive accepts and replies to keepalive@openssh.com requests. func (s *Server) handleKeepAlive(req *ssh.Request) { s.Logger.Debugf("Received %q: WantReply: %v", req.Type, req.WantReply) diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 9bb15c337a85b..5e36d0c8c5e83 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -1651,7 +1651,8 @@ func TestPTY(t *testing.T) { require.NoError(t, se.RequestPty(ctx, "xterm", 0, 0, ssh.TerminalModes{})) } -// TestEnv requests setting environment variables. (We are currently ignoring these requests) +// TestEnv requests setting environment variables via +// a "env" request. func TestEnv(t *testing.T) { t.Parallel() ctx := context.Background() @@ -1662,7 +1663,37 @@ func TestEnv(t *testing.T) { require.NoError(t, err) defer se.Close() - require.NoError(t, se.Setenv(ctx, "HOME", "/")) + require.NoError(t, se.Setenv(ctx, "HOME_TEST", "/test")) + output, err := se.Output(ctx, "env") + require.NoError(t, err) + require.Contains(t, string(output), "HOME_TEST=/test") +} + +// TestEnvs requests setting environment variables via +// a "envs@goteleport.com" request. +func TestEnvs(t *testing.T) { + t.Parallel() + ctx := context.Background() + + f := newFixtureWithoutDiskBasedLogging(t) + + se, err := f.ssh.clt.NewSession(ctx) + require.NoError(t, err) + defer se.Close() + + envs := map[string]string{ + "HOME_TEST": "/test", + "LLAMA": "ALPACA", + "FISH": "FROG", + } + + require.NoError(t, se.SetEnvs(ctx, envs)) + output, err := se.Output(ctx, "env") + require.NoError(t, err) + + for k, v := range envs { + require.Contains(t, string(output), k+"="+v) + } } // TestNoAuth tries to log in with no auth methods and should be rejected diff --git a/lib/srv/term.go b/lib/srv/term.go index a7ef55406311b..d649034fbec9c 100644 --- a/lib/srv/term.go +++ b/lib/srv/term.go @@ -682,9 +682,7 @@ func (t *remoteTerminal) prepareRemoteSession(ctx context.Context, session *trac teleport.SSHSessionID: string(scx.SessionID()), } - for k, v := range envs { - if err := session.Setenv(ctx, k, v); err != nil { - t.log.Debugf("Unable to set environment variable: %v: %v", k, v) - } + if err := session.SetEnvs(ctx, envs); err != nil { + t.log.WithError(err).Debug("Unable to set environment variables") } }