Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions api/observability/tracing/ssh/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ package ssh

import (
"context"
"encoding/json"
"fmt"
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -262,3 +264,155 @@ func TestNewSession(t *testing.T) {
})
}
}

// envReqParams are parameters for env request
type envReqParams struct {
Name string
Value string
}

// TestSetEnvs verifies that client uses EnvsRequest to
// send multiple envs and falls back to sending individual "env"
// requests if the server does not support EnvsRequests.
func TestSetEnvs(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
errChan := make(chan error, 5)

expected := map[string]string{"a": "1", "b": "2", "c": "3"}

// used to collect individual envs requests
envReqC := make(chan envReqParams, 3)

srv := newServer(t, func(conn *ssh.ServerConn, channels <-chan ssh.NewChannel, requests <-chan *ssh.Request) {
for {
select {
case <-ctx.Done():
return
case ch := <-channels:
switch {
case ch == nil:
return
case ch.ChannelType() == "session":
ch, reqs, err := ch.Accept()
if err != nil {
errChan <- trace.Wrap(err, "failed to accept session channel")
return
}

go func() {
defer ch.Close()
for i := 0; ; i++ {
select {
case <-ctx.Done():
return
case req := <-reqs:
if req == nil {
return
}

switch {
case i == 0 && req.Type == EnvsRequest: // accept 1st EnvsRequest
var envReq EnvsReq
if err := ssh.Unmarshal(req.Payload, &envReq); err != nil {
_ = req.Reply(false, []byte(err.Error()))
return
}

var envs map[string]string
if err := json.Unmarshal(envReq.EnvsJSON, &envs); err != nil {
_ = req.Reply(false, []byte(err.Error()))
return
}

for k, v := range expected {
actual, ok := envs[k]
if !ok {
_ = req.Reply(false, []byte(fmt.Sprintf("expected env %s not present", k)))
return
}

if actual != v {
_ = req.Reply(false, []byte(fmt.Sprintf("expected value %s for env %s, got %s", v, k, actual)))
return
}
}

_ = req.Reply(true, nil)
case i == 1 && req.Type == EnvsRequest: // reject additional EnvsRequest so we test fallbacks
_ = req.Reply(false, nil)
case i >= 2 && i <= len(expected)+2 && req.Type == "env": // accept individual "env" fallbacks.
var e envReqParams
if err := ssh.Unmarshal(req.Payload, &e); err != nil {
_ = req.Reply(false, []byte(err.Error()))
return
}
envReqC <- e
_ = req.Reply(true, nil)
default: // out of order or unexpected message
_ = req.Reply(false, []byte(fmt.Sprintf("unexpected ssh request %s on iteration %d", req.Type, i)))
Comment thread
rosstimothy marked this conversation as resolved.
errChan <- err
return
}
}
}
}()
default:
if err := ch.Reject(ssh.ConnectionFailed, fmt.Sprintf("unexpected channel %s", ch.ChannelType())); err != nil {
errChan <- err
return
}
}
}
}
})

go srv.Run(errChan)

// create a client and open a session
conn, chans, reqs := srv.GetClient(t)
client := NewClient(conn, chans, reqs)
session, err := client.NewSession(ctx)
require.NoError(t, err)

// the first request shouldn't fall back
t.Run("envs set via envs@goteleport.com", func(t *testing.T) {
require.NoError(t, session.SetEnvs(ctx, expected))

select {
case <-envReqC:
t.Fatal("env request received instead of an envs@goteleport.com request")
default:
}
})

// subsequent requests should fall back to standard "env" requests
t.Run("envs set individually", func(t *testing.T) {
require.NoError(t, session.SetEnvs(ctx, expected))

envs := map[string]string{}
envsTimeout := time.NewTimer(3 * time.Second)
defer envsTimeout.Stop()
for i := 0; i < len(expected); i++ {
select {
case env := <-envReqC:
envs[env.Name] = env.Value
case <-envsTimeout.C:
t.Fatalf("Time out waiting for env request %d to be processed", i)
}
}

for k, v := range expected {
actual, ok := envs[k]
require.True(t, ok, "expected env %s to be set", k)
require.Equal(t, v, actual, "expected value %s for env %s, got %s", v, k, actual)
}
})

select {
case err := <-errChan:
require.NoError(t, err)
default:
}
}
87 changes: 76 additions & 11 deletions api/observability/tracing/ssh/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@ package ssh

import (
"context"
"encoding/json"
"fmt"
"strings"

"github.com/gravitational/trace"
"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.10.0"
oteltrace "go.opentelemetry.io/otel/trace"
Expand Down Expand Up @@ -51,7 +54,8 @@ func (s *Session) SendRequest(ctx context.Context, name string, wantReply bool,

// no need to wrap payload here, the session's channel wrapper will do it for us
s.wrapper.addContext(ctx, name)
return s.Session.SendRequest(name, wantReply, payload)
ok, err := s.Session.SendRequest(name, wantReply, payload)
return ok, trace.Wrap(err)
}

// Setenv sets an environment variable that will be applied to any
Expand All @@ -72,7 +76,66 @@ func (s *Session) Setenv(ctx context.Context, name, value string) error {
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.Setenv(name, value)
return trace.Wrap(s.Session.Setenv(name, value))
}

// SetEnvs sets environment variables that will be applied to any
// command executed by Shell or Run. If the server does not handle
// [EnvsRequest] requests then the client falls back to sending individual
// "env" requests until all provided environment variables have been set
// or an error was received.
func (s *Session) SetEnvs(ctx context.Context, envs map[string]string) error {
config := tracing.NewConfig(s.wrapper.opts)
ctx, span := config.TracerProvider.Tracer(instrumentationName).Start(
ctx,
"ssh.SetEnvs",
oteltrace.WithSpanKind(oteltrace.SpanKindClient),
oteltrace.WithAttributes(
semconv.RPCServiceKey.String("ssh.Session"),
semconv.RPCMethodKey.String("SendRequest"),
semconv.RPCSystemKey.String("ssh"),
),
)
defer span.End()

if len(envs) == 0 {
return nil
}

// If the server isn't Teleport fallback to individual "env" requests
if !strings.HasPrefix(string(s.wrapper.ServerVersion()), "SSH-2.0-Teleport") {
return trace.Wrap(s.setEnvFallback(ctx, envs))
}

raw, err := json.Marshal(envs)
if err != nil {
return trace.Wrap(err)
}

s.wrapper.addContext(ctx, EnvsRequest)
ok, err := s.Session.SendRequest(EnvsRequest, true, ssh.Marshal(EnvsReq{EnvsJSON: raw}))
if err != nil {
return trace.Wrap(err)
}

// The server does not handle EnvsRequest requests so fall back
// to sending individual requests.
if !ok {
return trace.Wrap(s.setEnvFallback(ctx, envs))
}

return nil
}

// setEnvFallback sends an "env" request for each item in envs.
func (s *Session) setEnvFallback(ctx context.Context, envs map[string]string) error {
for k, v := range envs {
if err := s.Setenv(ctx, k, v); err != nil {
return trace.Wrap(err, "failed to set environment variable %s", k)
}
}

return nil
}

// RequestPty requests the association of a pty with the session on the remote host.
Expand All @@ -95,7 +158,7 @@ func (s *Session) RequestPty(ctx context.Context, term string, h, w int, termmod
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.RequestPty(term, h, w, termmodes)
return trace.Wrap(s.Session.RequestPty(term, h, w, termmodes))
}

// RequestSubsystem requests the association of a subsystem with the session on the remote host.
Expand All @@ -116,7 +179,7 @@ func (s *Session) RequestSubsystem(ctx context.Context, subsystem string) error
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.RequestSubsystem(subsystem)
return trace.Wrap(s.Session.RequestSubsystem(subsystem))
}

// WindowChange informs the remote host about a terminal window dimension change to h rows and w columns.
Expand All @@ -138,7 +201,7 @@ func (s *Session) WindowChange(ctx context.Context, h, w int) error {
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.WindowChange(h, w)
return trace.Wrap(s.Session.WindowChange(h, w))
}

// Signal sends the given signal to the remote process.
Expand All @@ -159,7 +222,7 @@ func (s *Session) Signal(ctx context.Context, sig ssh.Signal) error {
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.Signal(sig)
return trace.Wrap(s.Session.Signal(sig))
}

// Start runs cmd on the remote host. Typically, the remote
Expand All @@ -181,7 +244,7 @@ func (s *Session) Start(ctx context.Context, cmd string) error {
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.Start(cmd)
return trace.Wrap(s.Session.Start(cmd))
}

// Shell starts a login shell on the remote host. A Session only
Expand All @@ -202,7 +265,7 @@ func (s *Session) Shell(ctx context.Context) error {
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.Shell()
return trace.Wrap(s.Session.Shell())
}

// Run runs cmd on the remote host. Typically, the remote
Expand Down Expand Up @@ -234,7 +297,7 @@ func (s *Session) Run(ctx context.Context, cmd string) error {
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.Run(cmd)
return trace.Wrap(s.Session.Run(cmd))
}

// Output runs cmd on the remote host and returns its standard output.
Expand All @@ -254,7 +317,8 @@ func (s *Session) Output(ctx context.Context, cmd string) ([]byte, error) {
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.Output(cmd)
output, err := s.Session.Output(cmd)
return output, trace.Wrap(err)
}

// CombinedOutput runs cmd on the remote host and returns its combined
Expand All @@ -275,5 +339,6 @@ func (s *Session) CombinedOutput(ctx context.Context, cmd string) ([]byte, error
defer span.End()

s.wrapper.addContext(ctx, request)
return s.Session.CombinedOutput(cmd)
output, err := s.Session.CombinedOutput(cmd)
return output, trace.Wrap(err)
}
13 changes: 13 additions & 0 deletions api/observability/tracing/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ import (
)

const (
// EnvsRequest sets multiple environment variables that will be applied to any
// command executed by Shell or Run.
Comment thread
rosstimothy marked this conversation as resolved.
// See [EnvsReq] for the corresponding payload.
EnvsRequest = "envs@goteleport.com"

// TracingRequest is sent by clients to server to pass along tracing context.
TracingRequest = "tracing@goteleport.com"

Expand All @@ -45,6 +50,14 @@ const (
instrumentationName = "otelssh"
)

// EnvsReq contains json marshaled key:value pairs sent as the
// payload for an [EnvsRequest].
type EnvsReq struct {
// EnvsJSON is a json marshaled map[string]string containing
// environment variables.
EnvsJSON []byte `json:"envs"`
}

// ContextFromRequest extracts any tracing data provided via an Envelope
// in the ssh.Request payload. If the payload contains an Envelope, then
// the context returned will have tracing data populated from the remote
Expand Down
22 changes: 11 additions & 11 deletions lib/client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,22 +225,22 @@ func (ns *NodeSession) createServerSession(ctx context.Context) (*tracessh.Sessi
return nil, trace.Wrap(err)
}

envs := map[string]string{}

// pass language info into the remote session.
evarsToPass := []string{"LANG", "LANGUAGE"}
for _, evar := range evarsToPass {
if value := os.Getenv(evar); value != "" {
err = sess.Setenv(ctx, evar, value)
if err != nil {
log.Warn(err)
}
langVars := []string{"LANG", "LANGUAGE"}
for _, env := range langVars {
if value := os.Getenv(env); value != "" {
envs[env] = value
}
}
// pass environment variables set by client
for key, val := range ns.env {
err = sess.Setenv(ctx, key, val)
if err != nil {
log.Warn(err)
}
envs[key] = val
}

if err := sess.SetEnvs(ctx, envs); err != nil {
log.Warn(err)
Comment thread
codingllama marked this conversation as resolved.
}

// if agent forwarding was requested (and we have a agent to forward),
Expand Down
Loading