diff --git a/api/observability/tracing/tracing.go b/api/observability/tracing/tracing.go index a8fcece011801..c5613b3f95b69 100644 --- a/api/observability/tracing/tracing.go +++ b/api/observability/tracing/tracing.go @@ -26,6 +26,10 @@ import ( // PropagationContext contains tracing information to be passed across service boundaries type PropagationContext map[string]string +// TraceParent is the name of the header or query parameter that contains +// tracing context across service boundaries. +const TraceParent = "traceparent" + // PropagationContextFromContext creates a PropagationContext from the given context.Context. If the context // does not contain any tracing information, the PropagationContext will be empty. func PropagationContextFromContext(ctx context.Context, opts ...Option) PropagationContext { diff --git a/lib/httplib/httplib.go b/lib/httplib/httplib.go index 54872ce972509..fa3be1fa940bd 100644 --- a/lib/httplib/httplib.go +++ b/lib/httplib/httplib.go @@ -62,7 +62,27 @@ func MakeHandler(fn HandlerFunc) httprouter.Handle { // MakeTracingHandler returns a new httprouter.Handle func that wraps the provided handler func // with one that will add a tracing span for each request. func MakeTracingHandler(h http.Handler, component string) http.Handler { - return otelhttp.NewHandler(h, component, otelhttp.WithSpanNameFormatter(tracing.HTTPHandlerFormatter)) + // Wrap the provided handler with one that will inject + // any propagated tracing context provided via a query parameter + // if there isn't already a header containing tracing context. + // This is required for scenarios using web sockets as headers + // cannot be modified to inject the tracing context. + handler := func(w http.ResponseWriter, r *http.Request) { + // ensure headers have priority over query parameters + if r.Header.Get(tracing.TraceParent) != "" { + h.ServeHTTP(w, r) + return + } + + traceParent := r.URL.Query()[tracing.TraceParent] + if len(traceParent) > 0 { + r.Header.Add(tracing.TraceParent, traceParent[0]) + } + + h.ServeHTTP(w, r) + } + + return otelhttp.NewHandler(http.HandlerFunc(handler), component, otelhttp.WithSpanNameFormatter(tracing.HTTPHandlerFormatter)) } // MakeHandlerWithErrorWriter returns a httprouter.Handle from the HandlerFunc, diff --git a/lib/httplib/httplib_test.go b/lib/httplib/httplib_test.go index c51bee90a9efb..f5534f9efadea 100644 --- a/lib/httplib/httplib_test.go +++ b/lib/httplib/httplib_test.go @@ -29,6 +29,9 @@ import ( "github.com/gravitational/roundtrip" "github.com/julienschmidt/httprouter" "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/observability/tracing" ) type netError struct{} @@ -186,3 +189,77 @@ func TestReadJSON_ContentType(t *testing.T) { }) } } + +func TestMakeTracingHandler(t *testing.T) { + t.Parallel() + + newRequest := func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "", nil) + require.NoError(t, err) + + return req + } + + cases := []struct { + name string + req func(t *testing.T) *http.Request + headerAssertion func(t *testing.T, req *http.Request) + }{ + { + name: "no tracing context provided", + req: newRequest, + headerAssertion: func(t *testing.T, req *http.Request) { + require.Empty(t, req.Header.Get(tracing.TraceParent)) + }, + }, + { + name: "tracing context provided via header", + req: func(t *testing.T) *http.Request { + req := newRequest(t) + req.Header.Add(tracing.TraceParent, "test") + return req + }, + headerAssertion: func(t *testing.T, req *http.Request) { + require.Equal(t, "test", req.Header.Get(tracing.TraceParent)) + }, + }, + { + name: "tracing context provided via parameter", + req: func(t *testing.T) *http.Request { + req := newRequest(t) + q := req.URL.Query() + q.Set(tracing.TraceParent, "test") + req.URL.RawQuery = q.Encode() + return req + }, + headerAssertion: func(t *testing.T, req *http.Request) { + require.Equal(t, "test", req.Header.Get(tracing.TraceParent)) + }, + }, + { + name: "header has priority", + req: func(t *testing.T) *http.Request { + req := newRequest(t) + q := req.URL.Query() + req.Header.Add(tracing.TraceParent, "header") + q.Set(tracing.TraceParent, "parameter") + req.URL.RawQuery = q.Encode() + return req + }, + headerAssertion: func(t *testing.T, req *http.Request) { + require.Equal(t, "header", req.Header.Get(tracing.TraceParent)) + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + handler := MakeTracingHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + tt.headerAssertion(t, r) + }), teleport.ComponentProxy) + + handler.ServeHTTP(httptest.NewRecorder(), tt.req(t)) + }) + } + +} diff --git a/lib/service/service.go b/lib/service/service.go index 01b3735e7293e..c3e06dbd3aa40 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -3477,6 +3477,22 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { proxyKubeAddr = cfg.Proxy.Kube.PublicAddrs[0] } + traceClt := tracing.NewNoopClient() + if cfg.Tracing.Enabled { + traceConf, err := process.Config.Tracing.Config() + if err != nil { + return trace.Wrap(err) + } + traceConf.Logger = process.log.WithField(trace.Component, teleport.ComponentTracing) + + clt, err := tracing.NewStartedClient(process.ExitContext(), *traceConf) + if err != nil { + return trace.Wrap(err) + } + + traceClt = clt + } + webConfig := web.Config{ Proxy: tsrv, AuthServers: cfg.AuthServerAddresses()[0], @@ -3498,6 +3514,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error { PublicProxyAddr: process.proxyPublicAddr().Addr, ALPNHandler: alpnHandlerForWeb.HandleConnection, ProxyKubeAddr: proxyKubeAddr, + TraceClient: traceClt, } webHandler, err = web.NewHandler(webConfig) if err != nil { diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index b79244c39cc79..602511920a6a5 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -42,15 +42,20 @@ import ( "github.com/julienschmidt/httprouter" lemma_secret "github.com/mailgun/lemma/secret" "github.com/sirupsen/logrus" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace" + oteltrace "go.opentelemetry.io/otel/trace" + tracepb "go.opentelemetry.io/proto/otlp/trace/v1" "golang.org/x/crypto/ssh" "golang.org/x/exp/slices" "golang.org/x/mod/semver" + "google.golang.org/protobuf/encoding/protojson" "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" + apitracing "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/installers" @@ -65,7 +70,6 @@ import ( "github.com/gravitational/teleport/lib/httplib/csrf" "github.com/gravitational/teleport/lib/jwt" "github.com/gravitational/teleport/lib/limiter" - "github.com/gravitational/teleport/lib/observability/tracing" "github.com/gravitational/teleport/lib/plugin" "github.com/gravitational/teleport/lib/reversetunnel" "github.com/gravitational/teleport/lib/secret" @@ -206,6 +210,9 @@ type Config struct { // ALPNHandler is the ALPN connection handler for handling upgraded ALPN // connection through a HTTP upgrade call. ALPNHandler ConnectionHandler + + // TraceClient is used to forward spans to the upstream collector for the UI + TraceClient otlptrace.Client } type APIHandler struct { @@ -481,6 +488,9 @@ func (h *Handler) bindDefaultEndpoints(challengeLimiter *limiter.RateLimiter) { // Migrated this endpoint to /webapi/sessions/web below. h.POST("/webapi/sessions", httplib.WithCSRFProtection(h.createWebSession)) + // Forwards traces to the configured upstream collector + h.POST("/webapi/traces", h.WithAuth(h.traces)) + // Web sessions h.POST("/webapi/sessions/web", httplib.WithCSRFProtection(h.createWebSession)) h.POST("/webapi/sessions/app", h.WithAuth(h.createAppSession)) @@ -911,6 +921,70 @@ func getAuthSettings(ctx context.Context, authClient auth.ClientI) (webclient.Au return as, nil } +// traces forwards spans from the web ui to the upstream collector configured for the proxy. If tracing is +// disabled then the forwarding is a noop. +func (h *Handler) traces(w http.ResponseWriter, r *http.Request, _ httprouter.Params, _ *SessionContext) (interface{}, error) { + body, err := io.ReadAll(io.LimitReader(r.Body, teleport.MaxHTTPRequestSize)) + if err != nil { + h.log.WithError(err).Error("Failed to read traces request") + w.WriteHeader(http.StatusBadRequest) + return nil, nil + } + + if err := r.Body.Close(); err != nil { + h.log.WithError(err).Warn("Failed to close traces request body") + } + + var data tracepb.TracesData + if err := protojson.Unmarshal(body, &data); err != nil { + h.log.WithError(err).Error("Failed to unmarshal traces request") + w.WriteHeader(http.StatusBadRequest) + return nil, nil + } + + if len(data.ResourceSpans) == 0 { + w.WriteHeader(http.StatusBadRequest) + return nil, nil + } + + // Unmarshalling of TraceId, SpanId, and ParentSpanId might all yield incorrect values. The raw values from + // OpenTelemetry-js are hex encoded, but the unmarshal call above will decode them as base64. + // In order to ensure the ids are in the right format and won't be rejected by the upstream collector + // we attempt to convert them back into the base64 and then hex decode them. + for _, resourceSpan := range data.ResourceSpans { + for _, scopeSpan := range resourceSpan.ScopeSpans { + for _, span := range scopeSpan.Spans { + + // attempt to convert the trace id to the right format + if tid, err := oteltrace.TraceIDFromHex(base64.StdEncoding.EncodeToString(span.TraceId)); err == nil { + span.TraceId = tid[:] + } + + // attempt to convert the span id to the right format + if sid, err := oteltrace.SpanIDFromHex(base64.StdEncoding.EncodeToString(span.SpanId)); err == nil { + span.SpanId = sid[:] + } + + // attempt to convert the parent span id to the right format + if len(span.ParentSpanId) > 0 { + if psid, err := oteltrace.SpanIDFromHex(base64.StdEncoding.EncodeToString(span.ParentSpanId)); err == nil { + span.ParentSpanId = psid[:] + } + } + } + } + } + + go func() { + if err := h.cfg.TraceClient.UploadTraces(r.Context(), data.ResourceSpans); err != nil { + h.log.WithError(err).Error("Failed to upload traces") + } + }() + + w.WriteHeader(http.StatusOK) + return nil, nil +} + func (h *Handler) ping(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { var err error authSettings, err := getAuthSettings(r.Context(), h.cfg.ProxyClient) @@ -2151,7 +2225,7 @@ func (h *Handler) siteNodeConnect( return nil, trace.Wrap(err) } - netConfig, err := authAccessPoint.GetClusterNetworkingConfig(h.cfg.Context) + netConfig, err := authAccessPoint.GetClusterNetworkingConfig(r.Context()) if err != nil { h.log.WithError(err).Debug("Unable to fetch cluster networking config.") return nil, trace.Wrap(err) @@ -2175,7 +2249,7 @@ func (h *Handler) siteNodeConnect( // start the websocket session with a web-based terminal: h.log.Infof("Getting terminal to %#v.", req) - term.Serve(w, r) + httplib.MakeTracingHandler(term, teleport.ComponentProxy).ServeHTTP(w, r) return nil, nil } @@ -3106,7 +3180,7 @@ func makeTeleportClientConfig(ctx context.Context, sesCtx *SessionContext) (*cli DefaultPrincipal: cert.ValidPrincipals[0], HostKeyCallback: callback, TLSRoutingEnabled: proxyListenerMode == types.ProxyListenerMode_Multiplex, - Tracer: tracing.NoopProvider().Tracer("test"), + Tracer: apitracing.DefaultProvider().Tracer("webterminal"), } return config, nil diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 645968c846383..7e66fd146b96e 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -26,6 +26,7 @@ import ( "crypto/tls" "crypto/x509" "encoding/base32" + "encoding/base64" "encoding/hex" "encoding/json" "encoding/pem" @@ -59,9 +60,15 @@ import ( "github.com/pquerna/otp/totp" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" + commonv1 "go.opentelemetry.io/proto/otlp/common/v1" + resourcev1 "go.opentelemetry.io/proto/otlp/resource/v1" + otlp "go.opentelemetry.io/proto/otlp/trace/v1" + tracepb "go.opentelemetry.io/proto/otlp/trace/v1" "golang.org/x/crypto/ssh" "golang.org/x/exp/slices" "golang.org/x/text/encoding/unicode" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/testing/protocmp" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" @@ -6508,3 +6515,231 @@ func init() { &metav1.Status{}, ) } + +// TestForwardingTraces checks that the userContext includes the ID of the +// access request after it has been consumed and the web session has been renewed. +func TestForwardingTraces(t *testing.T) { + t.Parallel() + + env := newWebPack(t, 1) + p := env.proxies[0] + + newRequest := func(t *testing.T) *http.Request { + req, err := http.NewRequest(http.MethodGet, "", nil) + require.NoError(t, err) + + return req + } + + // Span captured from the UI which was marshaled by opentelemetry-js. + const rawSpan = `{"resourceSpans":[{"resource":{"attributes":[{"key":"service.name","value":{"stringValue":"web-ui"}},{"key":"telemetry.sdk.language","value":{"stringValue":"webjs"}},{"key":"telemetry.sdk.name","value":{"stringValue":"opentelemetry"}},{"key":"telemetry.sdk.version","value":{"stringValue":"1.7.0"}},{"key":"service.version","value":{"stringValue":"0.1.0"}}],"droppedAttributesCount":0},"scopeSpans":[{"scope":{"name":"@opentelemetry/instrumentation-fetch","version":"0.33.0"},"spans":[{"traceId":"255c8d876e7dbf3707ee8451ad518652","spanId":"d9edec516e598d8c","name":"HTTP GET","kind":3,"startTimeUnixNano":1668606426497000000,"endTimeUnixNano":1668502943215499800,"attributes":[{"key":"component","value":{"stringValue":"fetch"}},{"key":"http.method","value":{"stringValue":"GET"}},{"key":"http.url","value":{"stringValue":"https://proxy.example.com/v1/webapi/user/status"}},{"key":"http.status_code","value":{"intValue":0}},{"key":"http.status_text","value":{"stringValue":"Failed to fetch"}},{"key":"http.host","value":{"stringValue":"proxy.example.com"}},{"key":"http.scheme","value":{"stringValue":"https"}},{"key":"http.user_agent","value":{"stringValue":"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0 Safari/537.36 "}},{"key":"http.response_content_length","value":{"intValue":0}}],"droppedAttributesCount":0,"events":[{"attributes":[],"name":"fetchStart","timeUnixNano":1668502943210900000,"droppedAttributesCount":0},{"attributes":[],"name":"domainLookupStart","timeUnixNano":1668502687491499800,"droppedAttributesCount":0},{"attributes":[],"name":"domainLookupEnd","timeUnixNano":1668502687491499800,"droppedAttributesCount":0},{"attributes":[],"name":"connectStart","timeUnixNano":1668502687491499800,"droppedAttributesCount":0},{"attributes":[],"name":"secureConnectionStart","timeUnixNano":1668502687491499800,"droppedAttributesCount":0},{"attributes":[],"name":"connectEnd","timeUnixNano":1668502687491499800,"droppedAttributesCount":0},{"attributes":[],"name":"requestStart","timeUnixNano":1668502687491499800,"droppedAttributesCount":0},{"attributes":[],"name":"responseStart","timeUnixNano":1668502687491499800,"droppedAttributesCount":0},{"attributes":[],"name":"responseEnd","timeUnixNano":1668502943215100000,"droppedAttributesCount":0}],"droppedEventsCount":0,"status":{"code":0},"links":[],"droppedLinksCount":0}]}]}]}` + + // dummy span with arbitrary data, needed to be able to protojson.Marshal in tests + span := &tracepb.TracesData{ + ResourceSpans: []*tracepb.ResourceSpans{ + { + Resource: &resourcev1.Resource{ + Attributes: []*commonv1.KeyValue{ + { + Key: "test", + Value: &commonv1.AnyValue{ + Value: &commonv1.AnyValue_IntValue{ + IntValue: 0, + }, + }, + }, + }, + }, + ScopeSpans: []*tracepb.ScopeSpans{ + { + Spans: []*tracepb.Span{ + { + TraceId: []byte{1, 2, 3, 4}, + SpanId: []byte{5, 6, 7, 8}, + TraceState: "", + ParentSpanId: []byte{9, 10, 11, 12}, + Name: "test", + Kind: tracepb.Span_SPAN_KIND_CLIENT, + StartTimeUnixNano: uint64(time.Now().Add(-1 * time.Minute).Unix()), + EndTimeUnixNano: uint64(time.Now().Unix()), + Attributes: []*commonv1.KeyValue{ + { + Key: "test", + Value: &commonv1.AnyValue{ + Value: &commonv1.AnyValue_IntValue{ + IntValue: 11, + }, + }, + }, + }, + Status: &tracepb.Status{ + Message: "success!", + Code: tracepb.Status_STATUS_CODE_OK, + }, + }, + }, + }, + }, + }, + }, + } + + cases := []struct { + name string + req func(t *testing.T) *http.Request + assertion func(t *testing.T, spans []*otlp.ResourceSpans, err error, code int) + }{ + { + name: "no data", + req: func(t *testing.T) *http.Request { + r := newRequest(t) + r.Body = io.NopCloser(&bytes.Buffer{}) + return r + }, + assertion: func(t *testing.T, spans []*tracepb.ResourceSpans, err error, code int) { + require.NoError(t, err) + require.Equal(t, http.StatusBadRequest, code) + require.Empty(t, spans) + }, + }, + { + name: "invalid data", + req: func(t *testing.T) *http.Request { + r := newRequest(t) + r.Body = io.NopCloser(strings.NewReader(`{"test": "abc"}`)) + return r + }, + assertion: func(t *testing.T, spans []*tracepb.ResourceSpans, err error, code int) { + require.NoError(t, err) + require.Equal(t, http.StatusBadRequest, code) + require.Empty(t, spans) + }, + }, + { + name: "no traces", + req: func(t *testing.T) *http.Request { + r := newRequest(t) + + raw, err := protojson.Marshal(&tracepb.ResourceSpans{}) + require.NoError(t, err) + r.Body = io.NopCloser(bytes.NewBuffer(raw)) + + return r + }, + assertion: func(t *testing.T, spans []*tracepb.ResourceSpans, err error, code int) { + require.NoError(t, err) + require.Equal(t, http.StatusBadRequest, code) + require.Empty(t, spans) + }, + }, + { + name: "traces with base64 encoded ids", + req: func(t *testing.T) *http.Request { + r := newRequest(t) + + // Since the id fields of the span are all []byte, + // protojson will marshal them into base64 + raw, err := protojson.Marshal(span) + require.NoError(t, err) + r.Body = io.NopCloser(bytes.NewBuffer(raw)) + + return r + }, + assertion: func(t *testing.T, spans []*tracepb.ResourceSpans, err error, code int) { + require.NoError(t, err) + require.Equal(t, http.StatusOK, code) + require.Len(t, spans, 1) + require.Empty(t, cmp.Diff(span.ResourceSpans[0], spans[0], protocmp.Transform())) + }, + }, + { + name: "traces with hex encoded ids", + req: func(t *testing.T) *http.Request { + r := newRequest(t) + + // The id fields are hex encoded instead of base64 encoded + // by opentelemetry-js for the rawSpan + r.Body = io.NopCloser(strings.NewReader(rawSpan)) + + return r + }, + assertion: func(t *testing.T, spans []*otlp.ResourceSpans, err error, code int) { + require.NoError(t, err) + require.Equal(t, http.StatusOK, code) + require.Len(t, spans, 1) + + var data tracepb.TracesData + require.NoError(t, protojson.Unmarshal([]byte(rawSpan), &data)) + + // compare the spans, but ignore the ids since we know that the rawSpan + // has hex encoded ids and protojson.Unmarshal will give us an invalid value + require.Empty(t, cmp.Diff(data.ResourceSpans[0], spans[0], protocmp.Transform(), protocmp.IgnoreFields(&otlp.Span{}, "span_id", "trace_id"))) + + // compare the ids separately + sid1 := spans[0].ScopeSpans[0].Spans[0].SpanId + tid1 := spans[0].ScopeSpans[0].Spans[0].TraceId + + sid2 := data.ResourceSpans[0].ScopeSpans[0].Spans[0].SpanId + tid2 := data.ResourceSpans[0].ScopeSpans[0].Spans[0].TraceId + + require.Equal(t, hex.EncodeToString(sid1), base64.StdEncoding.EncodeToString(sid2)) + require.Equal(t, hex.EncodeToString(tid1), base64.StdEncoding.EncodeToString(tid2)) + }, + }, + } + + // NOTE: resetting the tracing client prevents + // the test cases from running in parallel + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + clt := &mockTraceClient{ + uploadReceived: make(chan struct{}), + } + p.handler.handler.cfg.TraceClient = clt + + recorder := httptest.NewRecorder() + + // use the handler directly because there is no easy way to pipe in our tracing + // data using the pack client in a format that would match the ui. + _, err := p.handler.handler.traces(recorder, tt.req(t), nil, nil) + + // if traces weren't uploaded perform the assertion + // without waiting for traces to be forwarded + if err != nil || recorder.Code != http.StatusOK { + tt.assertion(t, clt.spans, err, recorder.Code) + return + } + + // traces are forwarded in a goroutine, wait for them + // to be received by the trace client before doing the + // assertion + select { + case <-clt.uploadReceived: + case <-time.After(10 * time.Second): + t.Fatal("Timed out waiting for traces to be uploaded") + } + + tt.assertion(t, clt.spans, err, recorder.Code) + }) + } +} + +type mockTraceClient struct { + uploadError error + uploadReceived chan struct{} + spans []*otlp.ResourceSpans +} + +func (m *mockTraceClient) Start(ctx context.Context) error { + return nil +} + +func (m *mockTraceClient) Stop(ctx context.Context) error { + return nil +} + +func (m *mockTraceClient) UploadTraces(ctx context.Context, protoSpans []*otlp.ResourceSpans) error { + m.spans = append(m.spans, protoSpans...) + m.uploadReceived <- struct{}{} + return m.uploadError +} diff --git a/lib/web/terminal.go b/lib/web/terminal.go index c47e852db04fe..d1e1835058e53 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -30,12 +30,14 @@ import ( "github.com/gorilla/websocket" "github.com/gravitational/trace" "github.com/sirupsen/logrus" + oteltrace "go.opentelemetry.io/otel/trace" "golang.org/x/crypto/ssh" "golang.org/x/text/encoding" "golang.org/x/text/encoding/unicode" "github.com/gravitational/teleport" authproto "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/observability/tracing" tracessh "github.com/gravitational/teleport/api/observability/tracing/ssh" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" @@ -91,6 +93,9 @@ type AuthProvider interface { // NewTerminal creates a web-based terminal based on WebSockets and returns a // new TerminalHandler. func NewTerminal(ctx context.Context, req TerminalRequest, authProvider AuthProvider, sessCtx *SessionContext) (*TerminalHandler, error) { + ctx, span := tracing.DefaultProvider().Tracer("terminal").Start(ctx, "terminal/NewTerminal") + defer span.End() + // Make sure whatever session is requested is a valid session. _, err := session.ParseID(string(req.SessionID)) if err != nil { @@ -197,10 +202,10 @@ type TerminalHandler struct { join bool } -// Serve builds a connect to the remote node and then pumps back two types of +// ServeHTTP builds a connection to the remote node and then pumps back two types of // events: raw input/output events for what's happening on the terminal itself // and audit log events relevant to this session. -func (t *TerminalHandler) Serve(w http.ResponseWriter, r *http.Request) { +func (t *TerminalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // This allows closing of the websocket if the user logs out before exiting // the session. t.ctx.AddClosers(t) @@ -272,8 +277,10 @@ func (t *TerminalHandler) startPingLoop(ws *websocket.Conn) { func (t *TerminalHandler) handler(ws *websocket.Conn, r *http.Request) { defer ws.Close() - // Create a context for signaling when the terminal session is over. - t.terminalContext, t.terminalCancel = context.WithCancel(context.Background()) + // Create a context for signaling when the terminal session is over and + // link it first with the trace context from the request context + tctx := oteltrace.ContextWithRemoteSpanContext(context.Background(), oteltrace.SpanContextFromContext(r.Context())) + t.terminalContext, t.terminalCancel = context.WithCancel(tctx) // Create a Teleport client, if not able to, show the reason to the user in // the terminal. @@ -309,7 +316,10 @@ func (t *TerminalHandler) handler(ws *websocket.Conn, r *http.Request) { // makeClient builds a *client.TeleportClient for the connection. func (t *TerminalHandler) makeClient(ws *websocket.Conn, r *http.Request) (*client.TeleportClient, error) { - clientConfig, err := makeTeleportClientConfig(r.Context(), t.ctx) + ctx, span := tracing.DefaultProvider().Tracer("terminal").Start(r.Context(), "terminal/makeClient") + defer span.End() + + clientConfig, err := makeTeleportClientConfig(ctx, t.ctx) if err != nil { return nil, trace.Wrap(err) } @@ -337,6 +347,7 @@ func (t *TerminalHandler) makeClient(ws *websocket.Conn, r *http.Request) (*clie clientConfig.HostPort = t.hostPort clientConfig.Env = map[string]string{sshutils.SessionEnvVar: string(t.params.SessionID)} clientConfig.ClientAddr = r.RemoteAddr + clientConfig.Tracer = tracing.DefaultProvider().Tracer("TerminalHandler") if len(t.params.InteractiveCommand) > 0 { clientConfig.Interactive = true @@ -357,15 +368,18 @@ func (t *TerminalHandler) makeClient(ws *websocket.Conn, r *http.Request) (*clie return false, nil } - if err := t.issueSessionMFACerts(tc, ws); err != nil { + if err := t.issueSessionMFACerts(ctx, tc, ws); err != nil { return nil, trace.Wrap(err) } return tc, nil } -func (t *TerminalHandler) issueSessionMFACerts(tc *client.TeleportClient, ws *websocket.Conn) error { - pc, err := tc.ConnectToProxy(t.terminalContext) +func (t *TerminalHandler) issueSessionMFACerts(ctx context.Context, tc *client.TeleportClient, ws *websocket.Conn) error { + ctx, span := tracing.DefaultProvider().Tracer("terminal").Start(ctx, "terminal/issueSessionMFACerts") + defer span.End() + + pc, err := tc.ConnectToProxy(ctx) if err != nil { return trace.Wrap(err) } @@ -376,7 +390,7 @@ func (t *TerminalHandler) issueSessionMFACerts(tc *client.TeleportClient, ws *we return trace.Wrap(err) } - key, err := pc.IssueUserCertsWithMFA(t.terminalContext, client.ReissueParams{ + key, err := pc.IssueUserCertsWithMFA(ctx, client.ReissueParams{ RouteToCluster: t.params.Cluster, NodeName: t.params.Server, ExistingCreds: &client.Key{ @@ -694,7 +708,7 @@ func (t *TerminalHandler) read(out []byte, ws *websocket.Conn) (n int, err error // Send the window change request in a goroutine so reads are not blocked // by network connectivity issues. - go t.windowChange(context.TODO(), params) + go t.windowChange(t.terminalContext, params) return 0, nil default: