diff --git a/src/index.spec.ts b/src/index.spec.ts index d1b545f4..d5915806 100644 --- a/src/index.spec.ts +++ b/src/index.spec.ts @@ -63,6 +63,10 @@ jest.mock("./trace/trace-context-service", () => { get currentTraceHeaders() { return mockTraceHeaders; } + + reset() { + // mocking + } } return { ...jest.requireActual("./trace/trace-context-service"), diff --git a/src/trace/listener.spec.ts b/src/trace/listener.spec.ts index 2f8e2bb3..1e8dc189 100644 --- a/src/trace/listener.spec.ts +++ b/src/trace/listener.spec.ts @@ -68,6 +68,9 @@ jest.mock("./trace-context-service", () => { get currentTraceContext() { return mockSpanContextWrapper; } + reset() { + // mocking + } } return { ...jest.requireActual("./trace-context-service"), diff --git a/src/trace/listener.ts b/src/trace/listener.ts index 7003fb90..1e312a71 100644 --- a/src/trace/listener.ts +++ b/src/trace/listener.ts @@ -279,9 +279,10 @@ export class TraceListener { this.injectAuthorizerSpan(result, event?.requestContext?.requestId, finishTime || Date.now()); } - // Reset singleton + // Reset singletons and trace context this.stepFunctionContext = undefined; StepFunctionContextService.reset(); + this.contextService.reset(); } public onWrap any>(func: T): T { diff --git a/src/trace/trace-context-service.spec.ts b/src/trace/trace-context-service.spec.ts index 01fbf815..2e9c51bb 100644 --- a/src/trace/trace-context-service.spec.ts +++ b/src/trace/trace-context-service.spec.ts @@ -79,4 +79,51 @@ describe("TraceContextService", () => { expect(currentTraceContext?.sampleMode()).toBe(1); expect(currentTraceContext?.source).toBe("xray"); }); + + it("resets rootTraceContext to prevent caching between invocations", () => { + // Initial trace context + traceContextService["rootTraceContext"] = { + toTraceId: () => "123456", + toSpanId: () => "abcdef", + sampleMode: () => 1, + source: TraceSource.Event, + spanContext: spanContext, + }; + + expect(traceContextService.currentTraceContext).not.toBeNull(); + expect(traceContextService.traceSource).toBe("event"); + + traceContextService.reset(); + + expect(traceContextService.currentTraceContext).toBeNull(); + expect(traceContextService.traceSource).toBeNull(); + }); + + it("automatically resets trace context at the beginning of extract", async () => { + // Mock the extractor to return a specific context + const mockExtract = jest.fn().mockResolvedValue({ + toTraceId: () => "newTraceId", + toSpanId: () => "newSpanId", + sampleMode: () => 1, + source: TraceSource.Event, + spanContext: {}, + }); + traceContextService["traceExtractor"] = { extract: mockExtract } as any; + + // Set up old trace context (simulating previous invocation) + traceContextService["rootTraceContext"] = { + toTraceId: () => "oldTraceId", + toSpanId: () => "oldSpanId", + sampleMode: () => 0, + source: TraceSource.Xray, + spanContext: {}, + }; + + // Extract should reset and set new context + const result = await traceContextService.extract({}, {} as any); + + // Verify old context was cleared and new context was set + expect(result?.toTraceId()).toBe("newTraceId"); + expect(traceContextService.traceSource).toBe("event"); + }); }); diff --git a/src/trace/trace-context-service.ts b/src/trace/trace-context-service.ts index 8e1d2b55..247e7ff7 100644 --- a/src/trace/trace-context-service.ts +++ b/src/trace/trace-context-service.ts @@ -50,6 +50,9 @@ export class TraceContextService { } async extract(event: any, context: Context): Promise { + // Reset trace context from previous invocation to prevent caching + this.rootTraceContext = null; + this.rootTraceContext = await this.traceExtractor?.extract(event, context); return this.currentTraceContext; @@ -82,4 +85,8 @@ export class TraceContextService { get traceSource() { return this.rootTraceContext !== null ? this.rootTraceContext?.source : null; } + + reset() { + this.rootTraceContext = null; + } }