diff --git a/packages/opentelemetry-context-async-hooks/src/AsyncHooksContextManager.ts b/packages/opentelemetry-context-async-hooks/src/AsyncHooksContextManager.ts index f970ab397da..370d40a7dad 100644 --- a/packages/opentelemetry-context-async-hooks/src/AsyncHooksContextManager.ts +++ b/packages/opentelemetry-context-async-hooks/src/AsyncHooksContextManager.ts @@ -19,6 +19,7 @@ import * as asyncHooks from 'async_hooks'; import { EventEmitter } from 'events'; type Func = (...args: unknown[]) => T; +type UnPromisify = T extends Promise ? U : T; type PatchedEventEmitter = { /** @@ -29,19 +30,6 @@ type PatchedEventEmitter = { __ot_listeners?: { [name: string]: WeakMap, Func> }; } & EventEmitter; -class Reference { - constructor(private _value: T) {} - - set(value: T) { - this._value = value; - return this; - } - - get() { - return this._value; - } -} - const ADD_LISTENER_METHODS = [ 'addListener' as 'addListener', 'on' as 'on', @@ -52,72 +40,63 @@ const ADD_LISTENER_METHODS = [ export class AsyncHooksContextManager implements ContextManager { private _asyncHook: asyncHooks.AsyncHook; - private _contextRefs: Map | undefined> = new Map(); + private _contexts: Map = new Map(); + private _stack: Array = []; + private _active: Context | undefined = undefined; constructor() { this._asyncHook = asyncHooks.createHook({ init: this._init.bind(this), + before: this._before.bind(this), destroy: this._destroy.bind(this), promiseResolve: this._destroy.bind(this), }); } active(): Context { - const ref = this._contextRefs.get(asyncHooks.executionAsyncId()); - return ref === undefined ? Context.ROOT_CONTEXT : ref.get(); + return this._active ?? Context.ROOT_CONTEXT; } with ReturnType>( context: Context, fn: T ): ReturnType { - const uid = asyncHooks.executionAsyncId(); - let ref = this._contextRefs.get(uid); - let oldContext: Context | undefined = undefined; - if (ref === undefined) { - ref = new Reference(context); - this._contextRefs.set(uid, ref); - } else { - oldContext = ref.get(); - ref.set(context); - } + this._enterContext(context); try { - return fn(); - } finally { - if (oldContext === undefined) { - this._destroy(uid); - } else { - ref.set(oldContext); - } + const result = fn(); + this._exitContext(); + return result; + } catch (err) { + this._exitContext(); + throw err; } } - async withAsync, U extends (...args: unknown[]) => T>( + /** + * Run the async fn callback with object set as the current active context + * + * NOTE: This method is experimental + * + * @param context Any object to set as the current active context + * @param fn A async function to be immediately run within a specific context + */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + async withAsync Promise, U = UnPromisify>>( context: Context, - fn: U - ): Promise { - const uid = asyncHooks.executionAsyncId(); - let ref = this._contextRefs.get(uid); - let oldContext: Context | undefined = undefined; - if (ref === undefined) { - ref = new Reference(context); - this._contextRefs.set(uid, ref); - } else { - oldContext = ref.get(); - ref.set(context); - } + asyncFn: T + ): Promise { + this._enterContext(context); try { - return await fn(); - } finally { - if (oldContext === undefined) { - this._destroy(uid); - } else { - ref.set(oldContext); - } + const result = await asyncFn(); + this._exitContext(); + return result; + } catch (err) { + this._exitContext(); + throw err; } } - bind(target: T, context: Context): T { + bind(target: T, context?: Context): T { // if no specific context to propagate is given, we use the current one if (context === undefined) { context = this.active(); @@ -137,7 +116,8 @@ export class AsyncHooksContextManager implements ContextManager { disable(): this { this._asyncHook.disable(); - this._contextRefs.clear(); + this._contexts.clear(); + this._active = undefined; return this; } @@ -156,6 +136,7 @@ export class AsyncHooksContextManager implements ContextManager { * It isn't possible to tell Typescript that contextWrapper is the same as T * so we forced to cast as any here. */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any return contextWrapper as any; } @@ -270,10 +251,10 @@ export class AsyncHooksContextManager implements ContextManager { * context as the current one if it exist. * @param uid id of the async context */ - private _init(uid: number) { - const ref = this._contextRefs.get(asyncHooks.executionAsyncId()); - if (ref !== undefined) { - this._contextRefs.set(uid, ref); + private _init(uid: number, type: string, triggerId: number) { + const context = this._contexts.get(triggerId) ?? this._active; + if (context !== undefined) { + this._contexts.set(uid, context); } } @@ -283,6 +264,33 @@ export class AsyncHooksContextManager implements ContextManager { * @param uid uid of the async context */ private _destroy(uid: number) { - this._contextRefs.delete(uid); + this._contexts.delete(uid); + } + + /** + * Before hook is called just beforing entering a async context. + * @param uid uid of the async context + */ + private _before(uid: number) { + const context = this._contexts.get(uid); + if (context !== undefined) { + this._enterContext(context); + } + } + + /** + * Set the given context as active + */ + private _enterContext(context: Context) { + this._stack.push(context); + this._active = context; + } + + /** + * Remove current context from the stack and set as active the last one. + */ + private _exitContext() { + this._stack.pop(); + this._active = this._stack[this._stack.length - 1]; } } diff --git a/packages/opentelemetry-context-async-hooks/test/AsyncHooksContextManager.test.ts b/packages/opentelemetry-context-async-hooks/test/AsyncHooksContextManager.test.ts index 1c2429694ee..a69158b6c97 100644 --- a/packages/opentelemetry-context-async-hooks/test/AsyncHooksContextManager.test.ts +++ b/packages/opentelemetry-context-async-hooks/test/AsyncHooksContextManager.test.ts @@ -18,6 +18,7 @@ import * as assert from 'assert'; import { AsyncHooksContextManager } from '../src'; import { EventEmitter } from 'events'; import { Context } from '@opentelemetry/context-base'; +import * as asynchooks from 'async_hooks'; describe('AsyncHooksContextManager', () => { let contextManager: AsyncHooksContextManager; @@ -102,6 +103,35 @@ describe('AsyncHooksContextManager', () => { return done(); }); }); + + it('should finally restore an old context', done => { + const ctx1 = Context.ROOT_CONTEXT.setValue(key1, 'ctx1'); + contextManager.with(ctx1, () => { + assert.strictEqual(contextManager.active(), ctx1); + setTimeout(() => { + assert.strictEqual(contextManager.active(), ctx1); + return done(); + }); + }); + }); + + it('async function called from nested "with" sync function should return nested context', done => { + const scope1 = '1' as any; + const scope2 = '2' as any; + + const asyncFuncCalledDownstreamFromSync = async () => { + await (async () => {})(); + assert.strictEqual(contextManager.active(), scope2); + return done(); + }; + + contextManager.with(scope1, () => { + assert.strictEqual(contextManager.active(), scope1); + contextManager.with(scope2, () => asyncFuncCalledDownstreamFromSync()); + assert.strictEqual(contextManager.active(), scope1); + }); + assert.strictEqual(contextManager.active(), Context.ROOT_CONTEXT); + }); }); describe('.withAsync()', () => { @@ -149,10 +179,11 @@ describe('AsyncHooksContextManager', () => { }); it('should finally restore an old scope', async () => { - const scope1 = '1' as any; - const scope2 = '2' as any; + const scope1 = { '1': 1 } as any; + const scope2 = { '2': 2 } as any; let done = false; + assert.strictEqual(contextManager.active(), Context.ROOT_CONTEXT); await contextManager.withAsync(scope1, async () => { assert.strictEqual(contextManager.active(), scope1); await contextManager.withAsync(scope2, async () => { @@ -161,9 +192,24 @@ describe('AsyncHooksContextManager', () => { }); assert.strictEqual(contextManager.active(), scope1); }); - + assert.strictEqual(contextManager.active(), scope2); assert.ok(done); }); + + it('should keep scope across async operations', async () => { + const scope1 = { '1': 1 } as any; + + await contextManager.withAsync(scope1, () => { + return new Promise(resolve => { + assert.strictEqual(contextManager.active(), scope1); + setTimeout(() => { + assert.strictEqual(contextManager.active(), scope1); + return resolve(); + }, 5); + }); + }); + assert.strictEqual(contextManager.active(), scope1); + }); }); describe('.withAsync/with()', () => { @@ -180,93 +226,33 @@ describe('AsyncHooksContextManager', () => { }); assert.strictEqual(contextManager.active(), scope1); }); + assert.strictEqual(contextManager.active(), scope1); assert.ok(done); }); - it('withAsync() inside with() should correctly restore conxtext', done => { + it('with() inside a setTimeout inside withAsync() should correctly restore context', async () => { const scope1 = '1' as any; const scope2 = '2' as any; - contextManager.with(scope1, async () => { - assert.strictEqual(contextManager.active(), scope1); - await contextManager.withAsync(scope2, async () => { - assert.strictEqual(contextManager.active(), scope2); - }); - assert.strictEqual(contextManager.active(), scope1); - return done(); - }); - assert.strictEqual(contextManager.active(), Context.ROOT_CONTEXT); - }); - - it('not awaited withAsync() inside with() should not restore context', done => { - const scope1 = '1' as any; - const scope2 = '2' as any; - let _done = false; - - contextManager.with(scope1, () => { - assert.strictEqual(contextManager.active(), scope1); - contextManager - .withAsync(scope2, async () => { - assert.strictEqual(contextManager.active(), scope2); - }) - .then(() => { + await contextManager + .withAsync(scope1, () => { + return new Promise(resolve => { assert.strictEqual(contextManager.active(), scope1); - _done = true; - }); - // in this case the current scope is 2 since we - // didnt waited the withAsync call - assert.strictEqual(contextManager.active(), scope2); - setTimeout(() => { - assert.strictEqual(contextManager.active(), scope1); - assert(_done); - return done(); - }, 100); - }); - assert.strictEqual(contextManager.active(), Context.ROOT_CONTEXT); - }); - - it('withAsync() inside a setTimeout inside a with() should correctly restore context', done => { - const scope1 = '1' as any; - const scope2 = '2' as any; - - contextManager.with(scope1, () => { - assert.strictEqual(contextManager.active(), scope1); - setTimeout(() => { - assert.strictEqual(contextManager.active(), scope1); - contextManager - .withAsync(scope2, async () => { - assert.strictEqual(contextManager.active(), scope2); - }) - .then(() => { + setTimeout(() => { assert.strictEqual(contextManager.active(), scope1); - return done(); - }); - }, 5); - assert.strictEqual(contextManager.active(), scope1); - }); - assert.strictEqual(contextManager.active(), Context.ROOT_CONTEXT); - }); - - it('with() inside a setTimeout inside withAsync() should correctly restore context', done => { - const scope1 = '1' as any; - const scope2 = '2' as any; - - contextManager - .withAsync(scope1, async () => { - assert.strictEqual(contextManager.active(), scope1); - setTimeout(() => { + contextManager.with(scope2, () => { + assert.strictEqual(contextManager.active(), scope2); + return resolve(); + }); + }, 5); assert.strictEqual(contextManager.active(), scope1); - contextManager.with(scope2, () => { - assert.strictEqual(contextManager.active(), scope2); - return done(); - }); - }, 5); - assert.strictEqual(contextManager.active(), scope1); + }); }) .then(() => { assert.strictEqual(contextManager.active(), scope1); }); + assert.strictEqual(contextManager.active(), scope1); }); }); @@ -320,31 +306,15 @@ describe('AsyncHooksContextManager', () => { fn(); }); - it('should fail to return current context (when disabled + async op)', done => { - contextManager.disable(); - const context = Context.ROOT_CONTEXT.setValue(key1, 1); - const fn = contextManager.bind(() => { - setTimeout(() => { - assert.strictEqual( - contextManager.active(), - Context.ROOT_CONTEXT, - 'should have no context' - ); - return done(); - }, 100); - }, context); - fn(); - }); - - it('should return current context (when re-enabled + async op)', done => { - contextManager.enable(); + it('should fail to return current context with async op', done => { const context = Context.ROOT_CONTEXT.setValue(key1, 1); const fn = contextManager.bind(() => { + assert.strictEqual(contextManager.active(), context); setTimeout(() => { assert.strictEqual( contextManager.active(), context, - 'should have context' + 'should have no context' ); return done(); }, 100); @@ -363,7 +333,6 @@ describe('AsyncHooksContextManager', () => { const ee = new EventEmitter(); contextManager.disable(); assert.deepStrictEqual(contextManager.bind(ee, Context.ROOT_CONTEXT), ee); - contextManager.enable(); }); it('should return current context and removeListener (when enabled)', done => { @@ -409,7 +378,6 @@ describe('AsyncHooksContextManager', () => { assert.deepStrictEqual(contextManager.active(), context); patchedEe.removeListener('test', handler); assert.strictEqual(patchedEe.listeners('test').length, 0); - contextManager.enable(); return done(); }; patchedEe.on('test', handler); @@ -417,30 +385,12 @@ describe('AsyncHooksContextManager', () => { patchedEe.emit('test'); }); - it('should not return current context (when disabled + async op)', done => { - contextManager.disable(); - const ee = new EventEmitter(); - const context = Context.ROOT_CONTEXT.setValue(key1, 1); - const patchedEe = contextManager.bind(ee, context); - const handler = () => { - setImmediate(() => { - assert.deepStrictEqual(contextManager.active(), Context.ROOT_CONTEXT); - patchedEe.removeAllListeners('test'); - assert.strictEqual(patchedEe.listeners('test').length, 0); - return done(); - }); - }; - patchedEe.on('test', handler); - assert.strictEqual(patchedEe.listeners('test').length, 1); - patchedEe.emit('test'); - }); - - it('should return current context (when enabled + async op)', done => { - contextManager.enable(); + it('should not return current context with async op', done => { const ee = new EventEmitter(); const context = Context.ROOT_CONTEXT.setValue(key1, 1); const patchedEe = contextManager.bind(ee, context); const handler = () => { + assert.deepStrictEqual(contextManager.active(), context); setImmediate(() => { assert.deepStrictEqual(contextManager.active(), context); patchedEe.removeAllListeners('test');