diff --git a/internal/extension/extension.go b/internal/extension/extension.go index 33befb6d..6e0b890c 100644 --- a/internal/extension/extension.go +++ b/internal/extension/extension.go @@ -109,7 +109,10 @@ func (em *ExtensionManager) SendStartInvocationRequest(ctx context.Context, even if traceId != "" { ctx = context.WithValue(ctx, DdTraceId, traceId) } - parentId := response.Header.Get(string(DdParentId)) + parentId := traceId + if pid := response.Header.Get(string(DdParentId)); pid != "" { + parentId = pid + } if parentId != "" { ctx = context.WithValue(ctx, DdParentId, parentId) } diff --git a/internal/extension/extension_test.go b/internal/extension/extension_test.go index 3c960f08..d1c45036 100644 --- a/internal/extension/extension_test.go +++ b/internal/extension/extension_test.go @@ -174,6 +174,29 @@ func TestExtensionStartInvokeWithTraceContext(t *testing.T) { assert.Equal(t, mockSamplingPriority, samplingPriority) } +func TestExtensionStartInvokeWithTraceContextNoParentID(t *testing.T) { + headers := http.Header{} + headers.Set(string(DdTraceId), mockTraceId) + headers.Set(string(DdSamplingPriority), mockSamplingPriority) + + em := &ExtensionManager{ + startInvocationUrl: startInvocationUrl, + httpClient: &ClientSuccessStartInvoke{ + headers: headers, + }, + } + ctx := em.SendStartInvocationRequest(context.TODO(), []byte{}) + traceId := ctx.Value(DdTraceId) + parentId := ctx.Value(DdParentId) + samplingPriority := ctx.Value(DdSamplingPriority) + err := em.Flush() + + assert.Nil(t, err) + assert.Equal(t, mockTraceId, traceId) + assert.Equal(t, mockTraceId, parentId) + assert.Equal(t, mockSamplingPriority, samplingPriority) +} + func TestExtensionEndInvocation(t *testing.T) { em := &ExtensionManager{ endInvocationUrl: endInvocationUrl, diff --git a/internal/trace/context.go b/internal/trace/context.go index 7afc7784..6ed6c05b 100644 --- a/internal/trace/context.go +++ b/internal/trace/context.go @@ -17,6 +17,7 @@ import ( "strconv" "strings" + "github.com/DataDog/datadog-lambda-go/internal/extension" "github.com/DataDog/datadog-lambda-go/internal/logger" "github.com/aws/aws-xray-sdk-go/header" "github.com/aws/aws-xray-sdk-go/xray" @@ -47,7 +48,7 @@ var DefaultTraceExtractor = getHeadersFromEventHeaders // contextWithRootTraceContext uses the incoming event and context object payloads to determine // the root TraceContext and then adds that TraceContext to the context object. func contextWithRootTraceContext(ctx context.Context, ev json.RawMessage, mergeXrayTraces bool, extractor ContextExtractor) (context.Context, error) { - datadogTraceContext, gotDatadogTraceContext := getTraceContext(extractor(ctx, ev)) + datadogTraceContext, gotDatadogTraceContext := getTraceContext(ctx, extractor(ctx, ev)) xrayTraceContext, errGettingXrayContext := convertXrayTraceContextFromLambdaContext(ctx) if errGettingXrayContext != nil { @@ -126,21 +127,36 @@ func createDummySubsegmentForXrayConverter(ctx context.Context, traceCtx TraceCo return nil } -func getTraceContext(context map[string]string) (TraceContext, bool) { +func getTraceContext(ctx context.Context, headers map[string]string) (TraceContext, bool) { tc := TraceContext{} - traceID, ok := context[traceIDHeader] - if !ok { + traceID := headers[traceIDHeader] + if traceID == "" { + if val, ok := ctx.Value(extension.DdTraceId).(string); ok { + traceID = val + } + } + if traceID == "" { return tc, false } - parentID, ok := context[parentIDHeader] - if !ok { + parentID := headers[parentIDHeader] + if parentID == "" { + if val, ok := ctx.Value(extension.DdParentId).(string); ok { + parentID = val + } + } + if parentID == "" { return tc, false } - samplingPriority, ok := context[samplingPriorityHeader] - if !ok { + samplingPriority := headers[samplingPriorityHeader] + if samplingPriority == "" { + if val, ok := ctx.Value(extension.DdSamplingPriority).(string); ok { + samplingPriority = val + } + } + if samplingPriority == "" { samplingPriority = "1" //sampler-keep } diff --git a/internal/trace/context_test.go b/internal/trace/context_test.go index 3ef7fd70..303e0128 100644 --- a/internal/trace/context_test.go +++ b/internal/trace/context_test.go @@ -14,10 +14,9 @@ import ( "io/ioutil" "testing" + "github.com/DataDog/datadog-lambda-go/internal/extension" "github.com/aws/aws-xray-sdk-go/header" - "github.com/aws/aws-xray-sdk-go/xray" - "github.com/stretchr/testify/assert" ) @@ -45,6 +44,20 @@ func mockLambdaXRayTraceContext(ctx context.Context, traceID, parentID string, s return context.WithValue(ctx, xray.LambdaTraceHeaderKey, headerString) } +func mockTraceContext(traceID, parentID, samplingPriority string) context.Context { + ctx := context.Background() + if traceID != "" { + ctx = context.WithValue(ctx, extension.DdTraceId, traceID) + } + if parentID != "" { + ctx = context.WithValue(ctx, extension.DdParentId, parentID) + } + if samplingPriority != "" { + ctx = context.WithValue(ctx, extension.DdSamplingPriority, samplingPriority) + } + return ctx +} + func loadRawJSON(t *testing.T, filename string) *json.RawMessage { bytes, err := ioutil.ReadFile(filename) if err != nil { @@ -60,7 +73,7 @@ func TestGetDatadogTraceContextForTraceMetadataNonProxyEvent(t *testing.T) { ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true) ev := loadRawJSON(t, "../testdata/apig-event-with-headers.json") - headers, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev)) + headers, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev)) assert.True(t, ok) expected := TraceContext{ @@ -75,7 +88,7 @@ func TestGetDatadogTraceContextForTraceMetadataWithMixedCaseHeaders(t *testing.T ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true) ev := loadRawJSON(t, "../testdata/non-proxy-with-mixed-case-headers.json") - headers, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev)) + headers, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev)) assert.True(t, ok) expected := TraceContext{ @@ -90,7 +103,7 @@ func TestGetDatadogTraceContextForTraceMetadataWithMissingSamplingPriority(t *te ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true) ev := loadRawJSON(t, "../testdata/non-proxy-with-missing-sampling-priority.json") - headers, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev)) + headers, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev)) assert.True(t, ok) expected := TraceContext{ @@ -105,7 +118,7 @@ func TestGetDatadogTraceContextForInvalidData(t *testing.T) { ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true) ev := loadRawJSON(t, "../testdata/invalid.json") - _, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev)) + _, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev)) assert.False(t, ok) } @@ -113,10 +126,67 @@ func TestGetDatadogTraceContextForMissingData(t *testing.T) { ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true) ev := loadRawJSON(t, "../testdata/non-proxy-no-headers.json") - _, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev)) + _, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev)) assert.False(t, ok) } +func TestGetDatadogTraceContextFromContextObject(t *testing.T) { + testcases := []struct { + traceID string + parentID string + samplingPriority string + expectTC TraceContext + expectOk bool + }{ + { + "trace", + "parent", + "sampling", + TraceContext{ + "x-datadog-trace-id": "trace", + "x-datadog-parent-id": "parent", + "x-datadog-sampling-priority": "sampling", + }, + true, + }, + { + "", + "parent", + "sampling", + TraceContext{}, + false, + }, + { + "trace", + "", + "sampling", + TraceContext{}, + false, + }, + { + "trace", + "parent", + "", + TraceContext{ + "x-datadog-trace-id": "trace", + "x-datadog-parent-id": "parent", + "x-datadog-sampling-priority": "1", + }, + true, + }, + } + + ev := loadRawJSON(t, "../testdata/non-proxy-no-headers.json") + for _, test := range testcases { + t.Run(test.traceID+test.parentID+test.samplingPriority, func(t *testing.T) { + ctx := mockTraceContext(test.traceID, test.parentID, test.samplingPriority) + tc, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev)) + assert.Equal(t, test.expectTC, tc) + assert.Equal(t, test.expectOk, ok) + }) + } +} + func TestConvertXRayTraceID(t *testing.T) { output, err := convertXRayTraceIDToDatadogTraceID(mockXRayTraceID) assert.NoError(t, err) diff --git a/internal/trace/listener.go b/internal/trace/listener.go index 7725295c..7deddbd4 100644 --- a/internal/trace/listener.go +++ b/internal/trace/listener.go @@ -64,6 +64,10 @@ func (l *Listener) HandlerStarted(ctx context.Context, msg json.RawMessage) cont return ctx } + if l.universalInstrumentation && l.extensionManager.IsExtensionRunning() { + ctx = l.extensionManager.SendStartInvocationRequest(ctx, msg) + } + ctx, _ = contextWithRootTraceContext(ctx, msg, l.mergeXrayTraces, l.traceContextExtractor) if !tracerInitialized { @@ -77,15 +81,11 @@ func (l *Listener) HandlerStarted(ctx context.Context, msg json.RawMessage) cont } isDdServerlessSpan := l.universalInstrumentation && l.extensionManager.IsExtensionRunning() - functionExecutionSpan = startFunctionExecutionSpan(ctx, l.mergeXrayTraces, isDdServerlessSpan) + functionExecutionSpan, ctx = startFunctionExecutionSpan(ctx, l.mergeXrayTraces, isDdServerlessSpan) // Add the span to the context so the user can create child spans ctx = tracer.ContextWithSpan(ctx, functionExecutionSpan) - if l.universalInstrumentation && l.extensionManager.IsExtensionRunning() { - ctx = l.extensionManager.SendStartInvocationRequest(ctx, msg) - } - return ctx } @@ -104,7 +104,7 @@ func (l *Listener) HandlerFinished(ctx context.Context, err error) { // startFunctionExecutionSpan starts a span that represents the current Lambda function execution // and returns the span so that it can be finished when the function execution is complete -func startFunctionExecutionSpan(ctx context.Context, mergeXrayTraces bool, isDdServerlessSpan bool) tracer.Span { +func startFunctionExecutionSpan(ctx context.Context, mergeXrayTraces bool, isDdServerlessSpan bool) (tracer.Span, context.Context) { // Extract information from context lambdaCtx, _ := lambdacontext.FromContext(ctx) rootTraceContext, ok := ctx.Value(traceContextKey).(TraceContext) @@ -149,7 +149,9 @@ func startFunctionExecutionSpan(ctx context.Context, mergeXrayTraces bool, isDdS span.SetTag("_dd.parent_source", "xray") } - return span + ctx = context.WithValue(ctx, extension.DdSpanId, fmt.Sprint(span.Context().SpanID())) + + return span, ctx } func separateVersionFromFunctionArn(functionArn string) (arnWithoutVersion string, functionVersion string) { diff --git a/internal/trace/listener_test.go b/internal/trace/listener_test.go index 1e5176af..28d804df 100644 --- a/internal/trace/listener_test.go +++ b/internal/trace/listener_test.go @@ -10,6 +10,7 @@ package trace import ( "context" + "fmt" "testing" "github.com/DataDog/datadog-lambda-go/internal/extension" @@ -75,7 +76,7 @@ func TestStartFunctionExecutionSpanFromXrayWithMergeEnabled(t *testing.T) { mt := mocktracer.Start() defer mt.Stop() - span := startFunctionExecutionSpan(ctx, true, false) + span, ctx := startFunctionExecutionSpan(ctx, true, false) span.Finish() finishedSpan := mt.FinishedSpans()[0] @@ -91,6 +92,7 @@ func TestStartFunctionExecutionSpanFromXrayWithMergeEnabled(t *testing.T) { assert.Equal(t, "mockfunctionname", finishedSpan.Tag("functionname")) assert.Equal(t, "serverless", finishedSpan.Tag("span.type")) assert.Equal(t, "xray", finishedSpan.Tag("_dd.parent_source")) + assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string)) } func TestStartFunctionExecutionSpanFromXrayWithMergeDisabled(t *testing.T) { @@ -105,11 +107,12 @@ func TestStartFunctionExecutionSpanFromXrayWithMergeDisabled(t *testing.T) { mt := mocktracer.Start() defer mt.Stop() - span := startFunctionExecutionSpan(ctx, false, false) + span, ctx := startFunctionExecutionSpan(ctx, false, false) span.Finish() finishedSpan := mt.FinishedSpans()[0] assert.Equal(t, nil, finishedSpan.Tag("_dd.parent_source")) + assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string)) } func TestStartFunctionExecutionSpanFromEventWithMergeEnabled(t *testing.T) { @@ -124,11 +127,12 @@ func TestStartFunctionExecutionSpanFromEventWithMergeEnabled(t *testing.T) { mt := mocktracer.Start() defer mt.Stop() - span := startFunctionExecutionSpan(ctx, true, false) + span, ctx := startFunctionExecutionSpan(ctx, true, false) span.Finish() finishedSpan := mt.FinishedSpans()[0] assert.Equal(t, "xray", finishedSpan.Tag("_dd.parent_source")) + assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string)) } func TestStartFunctionExecutionSpanFromEventWithMergeDisabled(t *testing.T) { @@ -143,11 +147,12 @@ func TestStartFunctionExecutionSpanFromEventWithMergeDisabled(t *testing.T) { mt := mocktracer.Start() defer mt.Stop() - span := startFunctionExecutionSpan(ctx, false, false) + span, ctx := startFunctionExecutionSpan(ctx, false, false) span.Finish() finishedSpan := mt.FinishedSpans()[0] assert.Equal(t, nil, finishedSpan.Tag("_dd.parent_source")) + assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string)) } func TestStartFunctionExecutionSpanWithExtension(t *testing.T) { @@ -162,9 +167,10 @@ func TestStartFunctionExecutionSpanWithExtension(t *testing.T) { mt := mocktracer.Start() defer mt.Stop() - span := startFunctionExecutionSpan(ctx, false, true) + span, ctx := startFunctionExecutionSpan(ctx, false, true) span.Finish() finishedSpan := mt.FinishedSpans()[0] assert.Equal(t, string(extension.DdSeverlessSpan), finishedSpan.Tag("resource.name")) + assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string)) }