From fa1393bea7a6ef1fe294ec6d9a13204eb8a2d7ae Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Tue, 28 Mar 2023 17:09:02 -0400 Subject: [PATCH 1/5] Reduce time spent setting ssh session envs `tsh` sets a number of environment variables when setting up the users session. Each key value pair is transmitted one at a time in a "env" ssh request, which adds a num envs * RTT of additional latency per session. This introduces a new `envs@goteleport.com` request which sets multiple environment variables in a single ssh request, which reduces the amount of time spent setting envs down to the RTT of a single ssh request. In order to ensure backward compat and interoperability with OpenSSH, if the server does not recognize the `envs@goteleport.com` request the ssh client will resort to sending individual "env" requests. --- api/observability/tracing/ssh/client_test.go | 162 +++++++++++++++++++ api/observability/tracing/ssh/session.go | 61 +++++++ api/observability/tracing/ssh/ssh.go | 12 ++ lib/client/session.go | 22 +-- lib/srv/forward/sshserver.go | 26 ++- lib/srv/regular/sshserver.go | 31 +++- lib/srv/term.go | 6 +- 7 files changed, 301 insertions(+), 19 deletions(-) diff --git a/api/observability/tracing/ssh/client_test.go b/api/observability/tracing/ssh/client_test.go index 1d17fa5d84d5d..ad0aaa3613a04 100644 --- a/api/observability/tracing/ssh/client_test.go +++ b/api/observability/tracing/ssh/client_test.go @@ -16,6 +16,7 @@ package ssh import ( "context" + "encoding/json" "fmt" "testing" @@ -262,3 +263,164 @@ func TestNewSession(t *testing.T) { }) } } + +// envReqParams are parameters for env request +type envReqParams struct { + Name string + Value string +} + +// processEnvRequest unmarshals the env request and validates that the +// received k,v match the provided values. Any mismatch or failure to +// process the message results in sending a reply of false. +func processEnvRequest(req *ssh.Request) (string, string) { + var e envReqParams + if err := ssh.Unmarshal(req.Payload, &e); err != nil { + _ = req.Reply(false, []byte(err.Error())) + return "", "" + } + + _ = req.Reply(true, nil) + + return e.Name, e.Value +} + +// TestSetEnvs verifies that client uses EnvsRequest to +// send multiple envs and falls back to sending individual "env" +// requests if the server does not support EnvsRequests. +func TestSetEnvs(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + errChan := make(chan error, 5) + + expected := map[string]string{"a": "1", "b": "2", "c": "3"} + + srv := newServer(t, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { + for { + select { + case <-ctx.Done(): + return + + case ch := <-channels: + switch { + case ch == nil: + return + case ch.ChannelType() == "session": + ch, reqs, err := ch.Accept() + if err != nil { + errChan <- trace.Wrap(err, "failed to accept session channel") + return + } + + go func() { + // used to collect individual envs requests + fallback := map[string]string{} + + defer ch.Close() + for i := 0; ; i++ { + select { + case <-ctx.Done(): + return + + case req := <-reqs: + if req == nil { + return + } + + switch { + case i == 0 && req.Type == EnvsRequest: + var envReq EnvsReq + if err := ssh.Unmarshal(req.Payload, &envReq); err != nil { + _ = req.Reply(false, []byte(err.Error())) + return + } + + var envs map[string]string + if err := json.Unmarshal(envReq.Envs, &envs); err != nil { + _ = req.Reply(false, []byte(err.Error())) + return + } + + for k, v := range expected { + actual, ok := envs[k] + if !ok { + _ = req.Reply(false, []byte(fmt.Sprintf("expected env %s not present", k))) + return + } + + if actual != v { + _ = req.Reply(false, []byte(fmt.Sprintf("expected value %s for env %s, got %s", v, k, actual))) + return + } + } + + _ = req.Reply(true, nil) + case i == 1 && req.Type == EnvsRequest: + _ = req.Reply(false, nil) + case i == 2 && req.Type == "env": + k, v := processEnvRequest(req) + fallback[k] = v + case i == 3 && req.Type == "env": + k, v := processEnvRequest(req) + fallback[k] = v + case i == 4 && req.Type == "env": + k, v := processEnvRequest(req) + fallback[k] = v + + for k, v := range expected { + actual, ok := fallback[k] + if !ok { + _ = req.Reply(false, []byte(fmt.Sprintf("expected env %s not present", k))) + return + } + + if actual != v { + _ = req.Reply(false, []byte(fmt.Sprintf("expected value %s for env %s, got %s", v, k, actual))) + return + } + } + default: + _ = req.Reply(false, []byte(fmt.Sprintf("unexpected ssh request %s on iteration %d", req.Type, i))) + return + + } + } + } + + }() + + default: + if err := ch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unexpected channel %s", ch.ChannelType())); err != nil { + errChan <- err + return + } + } + } + } + }) + + go srv.Run(errChan) + + // create a client and open a session + conn, chans, reqs := srv.GetClient(t) + client := NewClient(conn, chans, reqs) + session, err := client.NewSession(ctx) + require.NoError(t, err) + + // the first request shouldn't fall back + t.Run("envs set via envs@goteleport.com", func(t *testing.T) { + require.NoError(t, session.SetEnvs(ctx, expected)) + }) + + // subsequent requests should fall back to standard "env" requests + t.Run("envs set individually", func(t *testing.T) { + require.NoError(t, session.SetEnvs(ctx, expected)) + }) + + select { + case err := <-errChan: + require.NoError(t, err) + default: + } +} diff --git a/api/observability/tracing/ssh/session.go b/api/observability/tracing/ssh/session.go index 5b53b503fe6af..0959bb69447c7 100644 --- a/api/observability/tracing/ssh/session.go +++ b/api/observability/tracing/ssh/session.go @@ -16,7 +16,9 @@ package ssh import ( "context" + "encoding/json" "fmt" + "strings" "go.opentelemetry.io/otel/attribute" semconv "go.opentelemetry.io/otel/semconv/v1.10.0" @@ -75,6 +77,65 @@ func (s *Session) Setenv(ctx context.Context, name, value string) error { return s.Session.Setenv(name, value) } +// SetEnvs sets environment variables that will be applied to any +// command executed by Shell or Run. If the server does not handle +// [EnvsRequest] requests then the client falls back to sending individual +// "env" requests until all provided environment variables have been set +// or an error was received. +func (s *Session) SetEnvs(ctx context.Context, envs map[string]string) error { + if len(envs) == 0 { + return nil + } + + // If the server isn't Teleport fallback to individual "env" requests + if !strings.HasPrefix(string(s.wrapper.ServerVersion()), "SSH-2.0-Teleport") { + return s.setEnvFallback(ctx, envs) + } + + config := tracing.NewConfig(s.wrapper.opts) + ctx, span := config.TracerProvider.Tracer(instrumentationName).Start( + ctx, + "ssh.SetEnvs", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + oteltrace.WithAttributes( + semconv.RPCServiceKey.String("ssh.Session"), + semconv.RPCMethodKey.String("SendRequest"), + semconv.RPCSystemKey.String("ssh"), + ), + ) + defer span.End() + + raw, err := json.Marshal(envs) + if err != nil { + return err + } + + s.wrapper.addContext(ctx, EnvsRequest) + ok, err := s.Session.SendRequest(EnvsRequest, true, ssh.Marshal(EnvsReq{Envs: raw})) + if err != nil { + return err + } + + // The server does not handle EnvsRequest requests so fall back + // to sending individual requests. + if !ok { + return s.setEnvFallback(ctx, envs) + } + + return nil +} + +// setEnvFallback sends an "env" request for each item in envs. +func (s *Session) setEnvFallback(ctx context.Context, envs map[string]string) error { + for k, v := range envs { + if err := s.Setenv(ctx, k, v); err != nil { + return err + } + } + + return nil +} + // RequestPty requests the association of a pty with the session on the remote host. func (s *Session) RequestPty(ctx context.Context, term string, h, w int, termmodes ssh.TerminalModes) error { const request = "pty-req" diff --git a/api/observability/tracing/ssh/ssh.go b/api/observability/tracing/ssh/ssh.go index cab328923e948..76d224181de1f 100644 --- a/api/observability/tracing/ssh/ssh.go +++ b/api/observability/tracing/ssh/ssh.go @@ -34,6 +34,10 @@ import ( ) const ( + // EnvsRequest sets multiple environment variables that will be applied to any + // command executed by Shell or Run. + EnvsRequest = "envs@goteleport.com" + // TracingRequest is sent by clients to server to pass along tracing context. TracingRequest = "tracing@goteleport.com" @@ -44,6 +48,14 @@ const ( instrumentationName = "otelssh" ) +// EnvsReq contains json marshaled key:value pairs sent as the +// payload for an EnvsRequest. +type EnvsReq struct { + // Envs is a json marshaled map[string]string containing + // environment variables. + Envs []byte +} + // ContextFromRequest extracts any tracing data provided via an Envelope // in the ssh.Request payload. If the payload contains an Envelope, then // the context returned will have tracing data populated from the remote diff --git a/lib/client/session.go b/lib/client/session.go index a0e95df6e0bc1..71f4584bddf2d 100644 --- a/lib/client/session.go +++ b/lib/client/session.go @@ -225,22 +225,22 @@ func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Sessi return nil, trace.Wrap(err) } + envs := map[string]string{} + // pass language info into the remote session. - evarsToPass := []string{"LANG", "LANGUAGE"} - for _, evar := range evarsToPass { - if value := os.Getenv(evar); value != "" { - err = sess.Setenv(ctx, evar, value) - if err != nil { - log.Warn(err) - } + langVars := []string{"LANG", "LANGUAGE"} + for _, env := range langVars { + if value := os.Getenv(env); value != "" { + envs[env] = value } } // pass environment variables set by client for key, val := range ns.env { - err = sess.Setenv(ctx, key, val) - if err != nil { - log.Warn(err) - } + envs[key] = val + } + + if err := sess.SetEnvs(ctx, envs); err != nil { + log.Warn(err) } // if agent forwarding was requested (and we have a agent to forward), diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 7d2d6817480ad..1fb9912552c4e 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -18,6 +18,7 @@ package forward import ( "context" + "encoding/json" "fmt" "io" "net" @@ -1046,7 +1047,7 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, return s.termHandlers.HandleWinChange(ctx, ch, req, scx) case teleport.ForceTerminateRequest: return s.termHandlers.HandleForceTerminate(ch, req, scx) - case sshutils.EnvRequest: + case sshutils.EnvRequest, tracessh.EnvsRequest: // We ignore all SSH setenv requests for join-only principals. // SSH will send them anyway but it seems fine to silently drop them. case sshutils.SubsystemRequest: @@ -1082,6 +1083,8 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, return s.termHandlers.HandleForceTerminate(ch, req, scx) case sshutils.EnvRequest: return s.handleEnv(ctx, ch, req, scx) + case tracessh.EnvsRequest: + return s.handleEnvs(ctx, ch, req, scx) case sshutils.SubsystemRequest: return s.handleSubsystem(ctx, ch, req, scx) case sshutils.X11ForwardRequest: @@ -1274,6 +1277,27 @@ func (s *Server) handleEnv(ctx context.Context, ch ssh.Channel, req *ssh.Request return nil } +// handleEnvs accepts environment variables sent by the client and forwards them +// to the remote session. +func (s *Server) handleEnvs(ctx context.Context, ch ssh.Channel, req *ssh.Request, scx *srv.ServerContext) error { + var raw tracessh.EnvsReq + if err := ssh.Unmarshal(req.Payload, &raw); err != nil { + scx.Error(err) + return trace.Wrap(err, "failed to parse envs request") + } + + var envs map[string]string + if err := json.Unmarshal(raw.Envs, &envs); err != nil { + return trace.Wrap(err, "failed to unmarshal envs") + } + + if err := scx.RemoteSession.SetEnvs(ctx, envs); err != nil { + s.log.Debugf("Unable to set environment variables: %v", err) + } + + return nil +} + func (s *Server) replyError(ch ssh.Channel, req *ssh.Request, err error) { s.log.WithError(err).Errorf("failure handling SSH %q request", req.Type) // Terminate the error with a newline when writing to remote channel's diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 18590157a60e9..63df0eb974914 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -20,6 +20,7 @@ package regular import ( "context" + "encoding/json" "fmt" "io" "net" @@ -1544,8 +1545,9 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, case sshutils.SubsystemRequest: return s.handleSubsystem(ctx, ch, req, serverContext) case sshutils.EnvRequest: - // we currently ignore setting any environment variables via SSH for security purposes return s.handleEnv(ch, req, serverContext) + case tracessh.EnvsRequest: + return s.handleEnvs(ch, req, serverContext) case sshutils.AgentForwardRequest: // process agent forwarding, but we will only forward agent to proxy in // recording proxy mode. @@ -1581,7 +1583,7 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, return s.termHandlers.HandleWinChange(ctx, ch, req, serverContext) case teleport.ForceTerminateRequest: return s.termHandlers.HandleForceTerminate(ch, req, serverContext) - case sshutils.EnvRequest: + case sshutils.EnvRequest, tracessh.EnvsRequest: // We ignore all SSH setenv requests for join-only principals. // SSH will send them anyway but it seems fine to silently drop them. case sshutils.SubsystemRequest: @@ -1629,6 +1631,8 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, return s.termHandlers.HandleForceTerminate(ch, req, serverContext) case sshutils.EnvRequest: return s.handleEnv(ch, req, serverContext) + case tracessh.EnvsRequest: + return s.handleEnvs(ch, req, serverContext) case sshutils.SubsystemRequest: // subsystems are SSH subsystems defined in http://tools.ietf.org/html/rfc4254 6.6 // they are in essence SSH session extensions, allowing to implement new SSH commands @@ -1795,7 +1799,7 @@ func (s *Server) handleSubsystem(ctx context.Context, ch ssh.Channel, req *ssh.R return nil } -// handleEnv accepts environment variables sent by the client and stores them +// handleEnv accepts an environment variable sent by the client and stores it // in connection context func (s *Server) handleEnv(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerContext) error { var e sshutils.EnvReqParams @@ -1807,6 +1811,27 @@ func (s *Server) handleEnv(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerCont return nil } +// handleEnvs accepts environment variables sent by the client and stores them +// in connection context +func (s *Server) handleEnvs(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerContext) error { + var raw tracessh.EnvsReq + if err := ssh.Unmarshal(req.Payload, &raw); err != nil { + ctx.Error(err) + return trace.Wrap(err, "failed to parse envs request") + } + + var envs map[string]string + if err := json.Unmarshal(raw.Envs, &envs); err != nil { + return trace.Wrap(err, "failed to unmarshal envs") + } + + for k, v := range envs { + ctx.SetEnv(k, v) + } + + return nil +} + // handleKeepAlive accepts and replies to keepalive@openssh.com requests. func (s *Server) handleKeepAlive(req *ssh.Request) { s.Logger.Debugf("Received %q: WantReply: %v", req.Type, req.WantReply) diff --git a/lib/srv/term.go b/lib/srv/term.go index 16d179ef491a7..5234ca15767e3 100644 --- a/lib/srv/term.go +++ b/lib/srv/term.go @@ -693,9 +693,7 @@ func (t *remoteTerminal) prepareRemoteSession(ctx context.Context, session *trac teleport.SSHSessionID: string(scx.SessionID()), } - for k, v := range envs { - if err := session.Setenv(ctx, k, v); err != nil { - t.log.Debugf("Unable to set environment variable: %v: %v", k, v) - } + if err := session.SetEnvs(ctx, envs); err != nil { + t.log.Debugf("Unable to set environment variables: %v", err) } } From 6c531350a01c5e2efa9b7359b67e528534d911a6 Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Wed, 29 Mar 2023 16:25:26 -0400 Subject: [PATCH 2/5] address feedback --- api/observability/tracing/ssh/client_test.go | 87 +++++++++----------- api/observability/tracing/ssh/session.go | 54 ++++++------ api/observability/tracing/ssh/ssh.go | 7 +- lib/srv/forward/sshserver.go | 4 +- lib/srv/regular/sshserver.go | 2 +- lib/srv/regular/sshserver_test.go | 35 +++++++- lib/srv/term.go | 2 +- 7 files changed, 109 insertions(+), 82 deletions(-) diff --git a/api/observability/tracing/ssh/client_test.go b/api/observability/tracing/ssh/client_test.go index ad0aaa3613a04..a30b0ac6ef33f 100644 --- a/api/observability/tracing/ssh/client_test.go +++ b/api/observability/tracing/ssh/client_test.go @@ -19,6 +19,7 @@ import ( "encoding/json" "fmt" "testing" + "time" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -270,21 +271,6 @@ type envReqParams struct { Value string } -// processEnvRequest unmarshals the env request and validates that the -// received k,v match the provided values. Any mismatch or failure to -// process the message results in sending a reply of false. -func processEnvRequest(req *ssh.Request) (string, string) { - var e envReqParams - if err := ssh.Unmarshal(req.Payload, &e); err != nil { - _ = req.Reply(false, []byte(err.Error())) - return "", "" - } - - _ = req.Reply(true, nil) - - return e.Name, e.Value -} - // TestSetEnvs verifies that client uses EnvsRequest to // send multiple envs and falls back to sending individual "env" // requests if the server does not support EnvsRequests. @@ -296,12 +282,14 @@ func TestSetEnvs(t *testing.T) { expected := map[string]string{"a": "1", "b": "2", "c": "3"} + // used to collect individual envs requests + envReqC := make(chan envReqParams, 3) + srv := newServer(t, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) { for { select { case <-ctx.Done(): return - case ch := <-channels: switch { case ch == nil: @@ -314,22 +302,18 @@ func TestSetEnvs(t *testing.T) { } go func() { - // used to collect individual envs requests - fallback := map[string]string{} - defer ch.Close() for i := 0; ; i++ { select { case <-ctx.Done(): return - case req := <-reqs: if req == nil { return } switch { - case i == 0 && req.Type == EnvsRequest: + case i == 0 && req.Type == EnvsRequest: // accept 1st EnvsRequest var envReq EnvsReq if err := ssh.Unmarshal(req.Payload, &envReq); err != nil { _ = req.Reply(false, []byte(err.Error())) @@ -337,7 +321,7 @@ func TestSetEnvs(t *testing.T) { } var envs map[string]string - if err := json.Unmarshal(envReq.Envs, &envs); err != nil { + if err := json.Unmarshal(envReq.EnvsJSON, &envs); err != nil { _ = req.Reply(false, []byte(err.Error())) return } @@ -356,40 +340,24 @@ func TestSetEnvs(t *testing.T) { } _ = req.Reply(true, nil) - case i == 1 && req.Type == EnvsRequest: + case i == 1 && req.Type == EnvsRequest: // reject additional EnvsRequest so we test fallbacks _ = req.Reply(false, nil) - case i == 2 && req.Type == "env": - k, v := processEnvRequest(req) - fallback[k] = v - case i == 3 && req.Type == "env": - k, v := processEnvRequest(req) - fallback[k] = v - case i == 4 && req.Type == "env": - k, v := processEnvRequest(req) - fallback[k] = v - - for k, v := range expected { - actual, ok := fallback[k] - if !ok { - _ = req.Reply(false, []byte(fmt.Sprintf("expected env %s not present", k))) - return - } - - if actual != v { - _ = req.Reply(false, []byte(fmt.Sprintf("expected value %s for env %s, got %s", v, k, actual))) - return - } + case i >= 2 && i <= len(expected)+2 && req.Type == "env": // accept individual "env" fallbacks. + var e envReqParams + if err := ssh.Unmarshal(req.Payload, &e); err != nil { + _ = req.Reply(false, []byte(err.Error())) + return } - default: + envReqC <- e + _ = req.Reply(true, nil) + default: // out of order or unexpected message _ = req.Reply(false, []byte(fmt.Sprintf("unexpected ssh request %s on iteration %d", req.Type, i))) + errChan <- err return - } } } - }() - default: if err := ch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unexpected channel %s", ch.ChannelType())); err != nil { errChan <- err @@ -411,11 +379,34 @@ func TestSetEnvs(t *testing.T) { // the first request shouldn't fall back t.Run("envs set via envs@goteleport.com", func(t *testing.T) { require.NoError(t, session.SetEnvs(ctx, expected)) + + select { + case <-envReqC: + t.Fatal("env request received instead of an envs@goteleport.com request") + default: + } }) // subsequent requests should fall back to standard "env" requests t.Run("envs set individually", func(t *testing.T) { require.NoError(t, session.SetEnvs(ctx, expected)) + + envs := map[string]string{} + for i := 0; i < len(expected); i++ { + select { + case env := <-envReqC: + envs[env.Name] = env.Value + case <-time.After(3 * time.Second): + t.Fatalf("time out waiting for env request %d to be processed", i) + } + } + + for k, v := range expected { + actual, ok := envs[k] + require.True(t, ok, "expected env %s to be set", k) + + require.Equal(t, v, actual, "expected value %s for env %s, got %s", v, k, actual) + } }) select { diff --git a/api/observability/tracing/ssh/session.go b/api/observability/tracing/ssh/session.go index 0959bb69447c7..bacb0b7855cb3 100644 --- a/api/observability/tracing/ssh/session.go +++ b/api/observability/tracing/ssh/session.go @@ -20,6 +20,7 @@ import ( "fmt" "strings" + "github.com/gravitational/trace" "go.opentelemetry.io/otel/attribute" semconv "go.opentelemetry.io/otel/semconv/v1.10.0" oteltrace "go.opentelemetry.io/otel/trace" @@ -53,7 +54,8 @@ func (s *Session) SendRequest(ctx context.Context, name string, wantReply bool, // no need to wrap payload here, the session's channel wrapper will do it for us s.wrapper.addContext(ctx, name) - return s.Session.SendRequest(name, wantReply, payload) + ok, err := s.Session.SendRequest(name, wantReply, payload) + return ok, trace.Wrap(err) } // Setenv sets an environment variable that will be applied to any @@ -74,7 +76,7 @@ func (s *Session) Setenv(ctx context.Context, name, value string) error { defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.Setenv(name, value) + return trace.Wrap(s.Session.Setenv(name, value)) } // SetEnvs sets environment variables that will be applied to any @@ -83,15 +85,6 @@ func (s *Session) Setenv(ctx context.Context, name, value string) error { // "env" requests until all provided environment variables have been set // or an error was received. func (s *Session) SetEnvs(ctx context.Context, envs map[string]string) error { - if len(envs) == 0 { - return nil - } - - // If the server isn't Teleport fallback to individual "env" requests - if !strings.HasPrefix(string(s.wrapper.ServerVersion()), "SSH-2.0-Teleport") { - return s.setEnvFallback(ctx, envs) - } - config := tracing.NewConfig(s.wrapper.opts) ctx, span := config.TracerProvider.Tracer(instrumentationName).Start( ctx, @@ -105,21 +98,30 @@ func (s *Session) SetEnvs(ctx context.Context, envs map[string]string) error { ) defer span.End() + if len(envs) == 0 { + return nil + } + + // If the server isn't Teleport fallback to individual "env" requests + if !strings.HasPrefix(string(s.wrapper.ServerVersion()), "SSH-2.0-Teleport") { + return trace.Wrap(s.setEnvFallback(ctx, envs)) + } + raw, err := json.Marshal(envs) if err != nil { - return err + return trace.Wrap(err) } s.wrapper.addContext(ctx, EnvsRequest) - ok, err := s.Session.SendRequest(EnvsRequest, true, ssh.Marshal(EnvsReq{Envs: raw})) + ok, err := s.Session.SendRequest(EnvsRequest, true, ssh.Marshal(EnvsReq{EnvsJSON: raw})) if err != nil { - return err + return trace.Wrap(err) } // The server does not handle EnvsRequest requests so fall back // to sending individual requests. if !ok { - return s.setEnvFallback(ctx, envs) + return trace.Wrap(s.setEnvFallback(ctx, envs)) } return nil @@ -129,7 +131,7 @@ func (s *Session) SetEnvs(ctx context.Context, envs map[string]string) error { func (s *Session) setEnvFallback(ctx context.Context, envs map[string]string) error { for k, v := range envs { if err := s.Setenv(ctx, k, v); err != nil { - return err + return trace.Wrap(err, "failed to set environment variable %s", k) } } @@ -156,7 +158,7 @@ func (s *Session) RequestPty(ctx context.Context, term string, h, w int, termmod defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.RequestPty(term, h, w, termmodes) + return trace.Wrap(s.Session.RequestPty(term, h, w, termmodes)) } // RequestSubsystem requests the association of a subsystem with the session on the remote host. @@ -177,7 +179,7 @@ func (s *Session) RequestSubsystem(ctx context.Context, subsystem string) error defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.RequestSubsystem(subsystem) + return trace.Wrap(s.Session.RequestSubsystem(subsystem)) } // WindowChange informs the remote host about a terminal window dimension change to h rows and w columns. @@ -199,7 +201,7 @@ func (s *Session) WindowChange(ctx context.Context, h, w int) error { defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.WindowChange(h, w) + return trace.Wrap(s.Session.WindowChange(h, w)) } // Signal sends the given signal to the remote process. @@ -220,7 +222,7 @@ func (s *Session) Signal(ctx context.Context, sig ssh.Signal) error { defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.Signal(sig) + return trace.Wrap(s.Session.Signal(sig)) } // Start runs cmd on the remote host. Typically, the remote @@ -242,7 +244,7 @@ func (s *Session) Start(ctx context.Context, cmd string) error { defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.Start(cmd) + return trace.Wrap(s.Session.Start(cmd)) } // Shell starts a login shell on the remote host. A Session only @@ -263,7 +265,7 @@ func (s *Session) Shell(ctx context.Context) error { defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.Shell() + return trace.Wrap(s.Session.Shell()) } // Run runs cmd on the remote host. Typically, the remote @@ -295,7 +297,7 @@ func (s *Session) Run(ctx context.Context, cmd string) error { defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.Run(cmd) + return trace.Wrap(s.Session.Run(cmd)) } // Output runs cmd on the remote host and returns its standard output. @@ -315,7 +317,8 @@ func (s *Session) Output(ctx context.Context, cmd string) ([]byte, error) { defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.Output(cmd) + output, err := s.Session.Output(cmd) + return output, trace.Wrap(err) } // CombinedOutput runs cmd on the remote host and returns its combined @@ -336,5 +339,6 @@ func (s *Session) CombinedOutput(ctx context.Context, cmd string) ([]byte, error defer span.End() s.wrapper.addContext(ctx, request) - return s.Session.CombinedOutput(cmd) + output, err := s.Session.CombinedOutput(cmd) + return output, trace.Wrap(err) } diff --git a/api/observability/tracing/ssh/ssh.go b/api/observability/tracing/ssh/ssh.go index 76d224181de1f..efe206a42f3ab 100644 --- a/api/observability/tracing/ssh/ssh.go +++ b/api/observability/tracing/ssh/ssh.go @@ -36,6 +36,7 @@ import ( const ( // EnvsRequest sets multiple environment variables that will be applied to any // command executed by Shell or Run. + // See [EnvsReq] for the corresponding payload. EnvsRequest = "envs@goteleport.com" // TracingRequest is sent by clients to server to pass along tracing context. @@ -49,11 +50,11 @@ const ( ) // EnvsReq contains json marshaled key:value pairs sent as the -// payload for an EnvsRequest. +// payload for an [EnvsRequest]. type EnvsReq struct { - // Envs is a json marshaled map[string]string containing + // EnvsJSON is a json marshaled map[string]string containing // environment variables. - Envs []byte + EnvsJSON []byte `json:"envs"` } // ContextFromRequest extracts any tracing data provided via an Envelope diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 1fb9912552c4e..2b1bbe867750b 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -1287,12 +1287,12 @@ func (s *Server) handleEnvs(ctx context.Context, ch ssh.Channel, req *ssh.Reques } var envs map[string]string - if err := json.Unmarshal(raw.Envs, &envs); err != nil { + if err := json.Unmarshal(raw.EnvsJSON, &envs); err != nil { return trace.Wrap(err, "failed to unmarshal envs") } if err := scx.RemoteSession.SetEnvs(ctx, envs); err != nil { - s.log.Debugf("Unable to set environment variables: %v", err) + s.log.WithError(err).Debug("Unable to set environment variables") } return nil diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 63df0eb974914..b346136ca5244 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -1821,7 +1821,7 @@ func (s *Server) handleEnvs(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerCon } var envs map[string]string - if err := json.Unmarshal(raw.Envs, &envs); err != nil { + if err := json.Unmarshal(raw.EnvsJSON, &envs); err != nil { return trace.Wrap(err, "failed to unmarshal envs") } diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index 92b11d39ae4b8..3b6025e1a6e0a 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -1649,7 +1649,8 @@ func TestPTY(t *testing.T) { require.NoError(t, se.RequestPty(ctx, "xterm", 0, 0, ssh.TerminalModes{})) } -// TestEnv requests setting environment variables. (We are currently ignoring these requests) +// TestEnv requests setting environment variables via +// a "env" request. func TestEnv(t *testing.T) { t.Parallel() ctx := context.Background() @@ -1660,7 +1661,37 @@ func TestEnv(t *testing.T) { require.NoError(t, err) defer se.Close() - require.NoError(t, se.Setenv(ctx, "HOME", "/")) + require.NoError(t, se.Setenv(ctx, "HOME_TEST", "/test")) + output, err := se.Output(ctx, "env") + require.NoError(t, err) + require.Contains(t, string(output), "HOME_TEST=/test") +} + +// TestEnvs requests setting environment variables via +// a "envs@goteleport.com" request. +func TestEnvs(t *testing.T) { + t.Parallel() + ctx := context.Background() + + f := newFixtureWithoutDiskBasedLogging(t) + + se, err := f.ssh.clt.NewSession(ctx) + require.NoError(t, err) + defer se.Close() + + envs := map[string]string{ + "HOME_TEST": "/test", + "LLAMA": "ALPACA", + "FISH": "FROG", + } + + require.NoError(t, se.SetEnvs(ctx, envs)) + output, err := se.Output(ctx, "env") + require.NoError(t, err) + + for k, v := range envs { + require.Contains(t, string(output), k+"="+v) + } } // TestNoAuth tries to log in with no auth methods and should be rejected diff --git a/lib/srv/term.go b/lib/srv/term.go index 5234ca15767e3..c714bd82fed63 100644 --- a/lib/srv/term.go +++ b/lib/srv/term.go @@ -694,6 +694,6 @@ func (t *remoteTerminal) prepareRemoteSession(ctx context.Context, session *trac } if err := session.SetEnvs(ctx, envs); err != nil { - t.log.Debugf("Unable to set environment variables: %v", err) + t.log.WithError(err).Debug("Unable to set environment variables") } } From 8683fb9baf27dd5e98c8c1b1c604644c861066b1 Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Wed, 29 Mar 2023 17:46:24 -0400 Subject: [PATCH 3/5] fix: use a single timer for fallback requests in tests Co-authored-by: Alan Parra --- api/observability/tracing/ssh/client_test.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/api/observability/tracing/ssh/client_test.go b/api/observability/tracing/ssh/client_test.go index a30b0ac6ef33f..41f2ab95032cc 100644 --- a/api/observability/tracing/ssh/client_test.go +++ b/api/observability/tracing/ssh/client_test.go @@ -392,12 +392,14 @@ func TestSetEnvs(t *testing.T) { require.NoError(t, session.SetEnvs(ctx, expected)) envs := map[string]string{} + envsTimeout := time.NewTimer(3*time.Second) + defer envsTimeout.Stop() for i := 0; i < len(expected); i++ { select { case env := <-envReqC: envs[env.Name] = env.Value - case <-time.After(3 * time.Second): - t.Fatalf("time out waiting for env request %d to be processed", i) + case <-envsTimeout.C: + t.Fatalf("Time out waiting for env request %d to be processed", i) } } From 5b4818e141e66d4269bb3485f244e12c09d001d5 Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Wed, 29 Mar 2023 17:46:40 -0400 Subject: [PATCH 4/5] fix: remove extra whitespace Co-authored-by: Alan Parra --- api/observability/tracing/ssh/client_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/api/observability/tracing/ssh/client_test.go b/api/observability/tracing/ssh/client_test.go index 41f2ab95032cc..b1eb4fba3a26b 100644 --- a/api/observability/tracing/ssh/client_test.go +++ b/api/observability/tracing/ssh/client_test.go @@ -406,7 +406,6 @@ func TestSetEnvs(t *testing.T) { for k, v := range expected { actual, ok := envs[k] require.True(t, ok, "expected env %s to be set", k) - require.Equal(t, v, actual, "expected value %s for env %s, got %s", v, k, actual) } }) From afb13f0a8cab68af521cb9c77171d6403b15b22f Mon Sep 17 00:00:00 2001 From: Tim Ross Date: Wed, 29 Mar 2023 17:55:08 -0400 Subject: [PATCH 5/5] fix: gci --- api/observability/tracing/ssh/client_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/observability/tracing/ssh/client_test.go b/api/observability/tracing/ssh/client_test.go index b1eb4fba3a26b..030df075ec9d9 100644 --- a/api/observability/tracing/ssh/client_test.go +++ b/api/observability/tracing/ssh/client_test.go @@ -392,7 +392,7 @@ func TestSetEnvs(t *testing.T) { require.NoError(t, session.SetEnvs(ctx, expected)) envs := map[string]string{} - envsTimeout := time.NewTimer(3*time.Second) + envsTimeout := time.NewTimer(3 * time.Second) defer envsTimeout.Stop() for i := 0; i < len(expected); i++ { select {