Skip to content
5 changes: 4 additions & 1 deletion internal/extension/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ func (em *ExtensionManager) SendStartInvocationRequest(ctx context.Context, even
if traceId != "" {
ctx = context.WithValue(ctx, DdTraceId, traceId)
}
parentId := response.Header.Get(string(DdParentId))
parentId := traceId
if pid := response.Header.Get(string(DdParentId)); pid != "" {
parentId = pid
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the parent id is not set, later on we will throw out the found trace context completely.

if parentId != "" {
ctx = context.WithValue(ctx, DdParentId, parentId)
}
Expand Down
23 changes: 23 additions & 0 deletions internal/extension/extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,29 @@ func TestExtensionStartInvokeWithTraceContext(t *testing.T) {
assert.Equal(t, mockSamplingPriority, samplingPriority)
}

func TestExtensionStartInvokeWithTraceContextNoParentID(t *testing.T) {
headers := http.Header{}
headers.Set(string(DdTraceId), mockTraceId)
headers.Set(string(DdSamplingPriority), mockSamplingPriority)

em := &ExtensionManager{
startInvocationUrl: startInvocationUrl,
httpClient: &ClientSuccessStartInvoke{
headers: headers,
},
}
ctx := em.SendStartInvocationRequest(context.TODO(), []byte{})
traceId := ctx.Value(DdTraceId)
parentId := ctx.Value(DdParentId)
samplingPriority := ctx.Value(DdSamplingPriority)
err := em.Flush()

assert.Nil(t, err)
assert.Equal(t, mockTraceId, traceId)
assert.Equal(t, mockTraceId, parentId)
assert.Equal(t, mockSamplingPriority, samplingPriority)
}

func TestExtensionEndInvocation(t *testing.T) {
em := &ExtensionManager{
endInvocationUrl: endInvocationUrl,
Expand Down
32 changes: 24 additions & 8 deletions internal/trace/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"strconv"
"strings"

"github.com/DataDog/datadog-lambda-go/internal/extension"
"github.com/DataDog/datadog-lambda-go/internal/logger"
"github.com/aws/aws-xray-sdk-go/header"
"github.com/aws/aws-xray-sdk-go/xray"
Expand Down Expand Up @@ -47,7 +48,7 @@ var DefaultTraceExtractor = getHeadersFromEventHeaders
// contextWithRootTraceContext uses the incoming event and context object payloads to determine
// the root TraceContext and then adds that TraceContext to the context object.
func contextWithRootTraceContext(ctx context.Context, ev json.RawMessage, mergeXrayTraces bool, extractor ContextExtractor) (context.Context, error) {
datadogTraceContext, gotDatadogTraceContext := getTraceContext(extractor(ctx, ev))
datadogTraceContext, gotDatadogTraceContext := getTraceContext(ctx, extractor(ctx, ev))

xrayTraceContext, errGettingXrayContext := convertXrayTraceContextFromLambdaContext(ctx)
if errGettingXrayContext != nil {
Expand Down Expand Up @@ -126,21 +127,36 @@ func createDummySubsegmentForXrayConverter(ctx context.Context, traceCtx TraceCo
return nil
}

func getTraceContext(context map[string]string) (TraceContext, bool) {
func getTraceContext(ctx context.Context, headers map[string]string) (TraceContext, bool) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the extension is in charge of parsing for trace context, it is included in the response from calling start invocation. It is then added to the context object. If present, it therefore should be read in case there is no trace context found in the event payload headers.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

by event payload headers, you mean the actual event that triggered the go lambda right?
Just wanna make sure I understand right, so we get trace context by either pinging the extension here which populates ctx

and headers could have trace context via some event such as api gw like this?

So to be sure we extract and attach context we do check both ways.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is exactly correct. Reading the "event payload headers" is hold over from before we added universal instrumentation. We could potentially get rid of it completely and just rely on the extension to do this work for us.

tc := TraceContext{}

traceID, ok := context[traceIDHeader]
if !ok {
traceID := headers[traceIDHeader]
if traceID == "" {
if val, ok := ctx.Value(extension.DdTraceId).(string); ok {
traceID = val
}
}
if traceID == "" {
return tc, false
}

parentID, ok := context[parentIDHeader]
if !ok {
parentID := headers[parentIDHeader]
if parentID == "" {
if val, ok := ctx.Value(extension.DdParentId).(string); ok {
parentID = val
}
}
if parentID == "" {
return tc, false
}

samplingPriority, ok := context[samplingPriorityHeader]
if !ok {
samplingPriority := headers[samplingPriorityHeader]
if samplingPriority == "" {
if val, ok := ctx.Value(extension.DdSamplingPriority).(string); ok {
samplingPriority = val
}
}
if samplingPriority == "" {
samplingPriority = "1" //sampler-keep
}

Expand Down
84 changes: 77 additions & 7 deletions internal/trace/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@ import (
"io/ioutil"
"testing"

"github.com/DataDog/datadog-lambda-go/internal/extension"
"github.com/aws/aws-xray-sdk-go/header"

"github.com/aws/aws-xray-sdk-go/xray"

"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -45,6 +44,20 @@ func mockLambdaXRayTraceContext(ctx context.Context, traceID, parentID string, s
return context.WithValue(ctx, xray.LambdaTraceHeaderKey, headerString)
}

func mockTraceContext(traceID, parentID, samplingPriority string) context.Context {
ctx := context.Background()
if traceID != "" {
ctx = context.WithValue(ctx, extension.DdTraceId, traceID)
}
if parentID != "" {
ctx = context.WithValue(ctx, extension.DdParentId, parentID)
}
if samplingPriority != "" {
ctx = context.WithValue(ctx, extension.DdSamplingPriority, samplingPriority)
}
return ctx
}

func loadRawJSON(t *testing.T, filename string) *json.RawMessage {
bytes, err := ioutil.ReadFile(filename)
if err != nil {
Expand All @@ -60,7 +73,7 @@ func TestGetDatadogTraceContextForTraceMetadataNonProxyEvent(t *testing.T) {
ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true)
ev := loadRawJSON(t, "../testdata/apig-event-with-headers.json")

headers, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev))
headers, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev))
assert.True(t, ok)

expected := TraceContext{
Expand All @@ -75,7 +88,7 @@ func TestGetDatadogTraceContextForTraceMetadataWithMixedCaseHeaders(t *testing.T
ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true)
ev := loadRawJSON(t, "../testdata/non-proxy-with-mixed-case-headers.json")

headers, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev))
headers, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev))
assert.True(t, ok)

expected := TraceContext{
Expand All @@ -90,7 +103,7 @@ func TestGetDatadogTraceContextForTraceMetadataWithMissingSamplingPriority(t *te
ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true)
ev := loadRawJSON(t, "../testdata/non-proxy-with-missing-sampling-priority.json")

headers, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev))
headers, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev))
assert.True(t, ok)

expected := TraceContext{
Expand All @@ -105,18 +118,75 @@ func TestGetDatadogTraceContextForInvalidData(t *testing.T) {
ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true)
ev := loadRawJSON(t, "../testdata/invalid.json")

_, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev))
_, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev))
assert.False(t, ok)
}

func TestGetDatadogTraceContextForMissingData(t *testing.T) {
ctx := mockLambdaXRayTraceContext(context.Background(), mockXRayTraceID, mockXRayEntityID, true)
ev := loadRawJSON(t, "../testdata/non-proxy-no-headers.json")

_, ok := getTraceContext(getHeadersFromEventHeaders(ctx, *ev))
_, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev))
assert.False(t, ok)
}

func TestGetDatadogTraceContextFromContextObject(t *testing.T) {
testcases := []struct {
traceID string
parentID string
samplingPriority string
expectTC TraceContext
expectOk bool
}{
{
"trace",
"parent",
"sampling",
TraceContext{
"x-datadog-trace-id": "trace",
"x-datadog-parent-id": "parent",
"x-datadog-sampling-priority": "sampling",
},
true,
},
{
"",
"parent",
"sampling",
TraceContext{},
false,
},
{
"trace",
"",
"sampling",
TraceContext{},
false,
},
{
"trace",
"parent",
"",
TraceContext{
"x-datadog-trace-id": "trace",
"x-datadog-parent-id": "parent",
"x-datadog-sampling-priority": "1",
},
true,
},
}

ev := loadRawJSON(t, "../testdata/non-proxy-no-headers.json")
for _, test := range testcases {
t.Run(test.traceID+test.parentID+test.samplingPriority, func(t *testing.T) {
ctx := mockTraceContext(test.traceID, test.parentID, test.samplingPriority)
tc, ok := getTraceContext(ctx, getHeadersFromEventHeaders(ctx, *ev))
assert.Equal(t, test.expectTC, tc)
assert.Equal(t, test.expectOk, ok)
})
}
}

func TestConvertXRayTraceID(t *testing.T) {
output, err := convertXRayTraceIDToDatadogTraceID(mockXRayTraceID)
assert.NoError(t, err)
Expand Down
16 changes: 9 additions & 7 deletions internal/trace/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ func (l *Listener) HandlerStarted(ctx context.Context, msg json.RawMessage) cont
return ctx
}

if l.universalInstrumentation && l.extensionManager.IsExtensionRunning() {
ctx = l.extensionManager.SendStartInvocationRequest(ctx, msg)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling SendStartInvocationRequest gives the extension the event payload. It then parses it for trace context, creates any inferred spans, and returns the trace context. Therefore, we need to get this trace context from the extension before starting the execution span here in the function.

}

ctx, _ = contextWithRootTraceContext(ctx, msg, l.mergeXrayTraces, l.traceContextExtractor)

if !tracerInitialized {
Expand All @@ -77,15 +81,11 @@ func (l *Listener) HandlerStarted(ctx context.Context, msg json.RawMessage) cont
}

isDdServerlessSpan := l.universalInstrumentation && l.extensionManager.IsExtensionRunning()
functionExecutionSpan = startFunctionExecutionSpan(ctx, l.mergeXrayTraces, isDdServerlessSpan)
functionExecutionSpan, ctx = startFunctionExecutionSpan(ctx, l.mergeXrayTraces, isDdServerlessSpan)

// Add the span to the context so the user can create child spans
ctx = tracer.ContextWithSpan(ctx, functionExecutionSpan)

if l.universalInstrumentation && l.extensionManager.IsExtensionRunning() {
ctx = l.extensionManager.SendStartInvocationRequest(ctx, msg)
}

return ctx
}

Expand All @@ -104,7 +104,7 @@ func (l *Listener) HandlerFinished(ctx context.Context, err error) {

// startFunctionExecutionSpan starts a span that represents the current Lambda function execution
// and returns the span so that it can be finished when the function execution is complete
func startFunctionExecutionSpan(ctx context.Context, mergeXrayTraces bool, isDdServerlessSpan bool) tracer.Span {
func startFunctionExecutionSpan(ctx context.Context, mergeXrayTraces bool, isDdServerlessSpan bool) (tracer.Span, context.Context) {
// Extract information from context
lambdaCtx, _ := lambdacontext.FromContext(ctx)
rootTraceContext, ok := ctx.Value(traceContextKey).(TraceContext)
Expand Down Expand Up @@ -149,7 +149,9 @@ func startFunctionExecutionSpan(ctx context.Context, mergeXrayTraces bool, isDdS
span.SetTag("_dd.parent_source", "xray")
}

return span
ctx = context.WithValue(ctx, extension.DdSpanId, fmt.Sprint(span.Context().SpanID()))
Copy link
Contributor Author

@purple4reina purple4reina Sep 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add the span id to the context for use later when calling end invocation on the extension. (see

req.Header.Set(string(DdSpanId), spanId)
)

The extension uses this span id to replace the span id of the aws.lambda span that it creates.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly right. You'll notice that the tracer and the extension are creating aws.lambda spans. However, the extension knows to throw out the one made in the tracer. The extension will ensure that both of these aws.lambda spans have the same trace context (span id, trace, id, parent id).

We do it this way because we want to make sure the tracer has access to the aws.lambda span as the parent of any spans it creates itself.


return span, ctx
}

func separateVersionFromFunctionArn(functionArn string) (arnWithoutVersion string, functionVersion string) {
Expand Down
16 changes: 11 additions & 5 deletions internal/trace/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package trace

import (
"context"
"fmt"
"testing"

"github.com/DataDog/datadog-lambda-go/internal/extension"
Expand Down Expand Up @@ -75,7 +76,7 @@ func TestStartFunctionExecutionSpanFromXrayWithMergeEnabled(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

span := startFunctionExecutionSpan(ctx, true, false)
span, ctx := startFunctionExecutionSpan(ctx, true, false)
span.Finish()
finishedSpan := mt.FinishedSpans()[0]

Expand All @@ -91,6 +92,7 @@ func TestStartFunctionExecutionSpanFromXrayWithMergeEnabled(t *testing.T) {
assert.Equal(t, "mockfunctionname", finishedSpan.Tag("functionname"))
assert.Equal(t, "serverless", finishedSpan.Tag("span.type"))
assert.Equal(t, "xray", finishedSpan.Tag("_dd.parent_source"))
assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string))
}

func TestStartFunctionExecutionSpanFromXrayWithMergeDisabled(t *testing.T) {
Expand All @@ -105,11 +107,12 @@ func TestStartFunctionExecutionSpanFromXrayWithMergeDisabled(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

span := startFunctionExecutionSpan(ctx, false, false)
span, ctx := startFunctionExecutionSpan(ctx, false, false)
span.Finish()
finishedSpan := mt.FinishedSpans()[0]

assert.Equal(t, nil, finishedSpan.Tag("_dd.parent_source"))
assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string))
}

func TestStartFunctionExecutionSpanFromEventWithMergeEnabled(t *testing.T) {
Expand All @@ -124,11 +127,12 @@ func TestStartFunctionExecutionSpanFromEventWithMergeEnabled(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

span := startFunctionExecutionSpan(ctx, true, false)
span, ctx := startFunctionExecutionSpan(ctx, true, false)
span.Finish()
finishedSpan := mt.FinishedSpans()[0]

assert.Equal(t, "xray", finishedSpan.Tag("_dd.parent_source"))
assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string))
}

func TestStartFunctionExecutionSpanFromEventWithMergeDisabled(t *testing.T) {
Expand All @@ -143,11 +147,12 @@ func TestStartFunctionExecutionSpanFromEventWithMergeDisabled(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

span := startFunctionExecutionSpan(ctx, false, false)
span, ctx := startFunctionExecutionSpan(ctx, false, false)
span.Finish()
finishedSpan := mt.FinishedSpans()[0]

assert.Equal(t, nil, finishedSpan.Tag("_dd.parent_source"))
assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string))
}

func TestStartFunctionExecutionSpanWithExtension(t *testing.T) {
Expand All @@ -162,9 +167,10 @@ func TestStartFunctionExecutionSpanWithExtension(t *testing.T) {
mt := mocktracer.Start()
defer mt.Stop()

span := startFunctionExecutionSpan(ctx, false, true)
span, ctx := startFunctionExecutionSpan(ctx, false, true)
span.Finish()
finishedSpan := mt.FinishedSpans()[0]

assert.Equal(t, string(extension.DdSeverlessSpan), finishedSpan.Tag("resource.name"))
assert.Equal(t, fmt.Sprint(span.Context().SpanID()), ctx.Value(extension.DdSpanId).(string))
}