diff --git a/index.js b/index.js index 425da1d715..04805bc78a 100644 --- a/index.js +++ b/index.js @@ -12,8 +12,11 @@ */ const newrelic = require('newrelic') newrelic.instrumentConglomerate('aws-sdk', require('./lib/instrumentation')) -newrelic.instrumentMessages('@aws-sdk/client-sns', require('./lib/v3-sns')) newrelic.instrument({ moduleName: '@aws-sdk/smithy-client', onResolved: require('./lib/smithy-client') }) +newrelic.instrumentMessages({ + moduleName: '@aws-sdk/client-sns', + onResolved: require('./lib/v3-sns') +}) diff --git a/lib/v3-sns.js b/lib/v3-sns.js index 128058e873..bedd4ea3b1 100644 --- a/lib/v3-sns.js +++ b/lib/v3-sns.js @@ -5,32 +5,100 @@ 'use strict' -function wrapClientSend(shim, original, name, args) { - const { constructor, input } = args[0] - const type = constructor.name - if (type === 'PublishCommand') { - return { - callback: shim.LAST, - destinationName: getDestinationName(input), - destinationType: shim.TOPIC, - opaque: true - } +module.exports = function instrument(shim, name, resolvedName) { + const fileNameIndex = resolvedName.indexOf('/index') + const relativeFolder = resolvedName.substr(0, fileNameIndex) + + // The path changes depending on the version... so we don't want to hard-code the relative + // path from the module root. + const snsClientExport = shim.require(`${relativeFolder}/SNSClient`) + + if (!shim.isFunction(snsClientExport.SNSClient)) { + shim.logger.debug('Could not find SNSClient, not instrumenting.') + return } - // eslint-disable-next-line consistent-return - return + shim.setLibrary(shim.SNS) + shim.wrapReturn( + snsClientExport, + 'SNSClient', + function wrappedReturn(shim, original, fnName, instance) { + postClientConstructor.call(instance, shim) + } + ) } -function getDestinationName({ TopicArn, TargetArn }) { - return TopicArn || TargetArn || 'PhoneNumber' // We don't want the value of PhoneNumber +/** + * Calls the instances middlewareStack.use to register + * a plugin that adds a middleware to record the time it teakes to publish a message + * see: https://aws.amazon.com/blogs/developer/middleware-stack-modular-aws-sdk-js/ + * + * @param {Shim} shim + */ +function postClientConstructor(shim) { + this.middlewareStack.use(getPlugin(shim)) } -module.exports = function instrument(shim, AWS) { - if (!shim.isFunction(AWS.SNS)) { - shim.logger.debug('Could not find SNS, not instrumenting.') - return +/** + * Returns the plugin object that adds middleware + * + * @param {Shim} shim + * @returns {object} + */ +function getPlugin(shim) { + return { + applyToStack: (clientStack) => { + clientStack.add(snsMiddleware.bind(null, shim), { + name: 'NewRelicSnsMiddleware', + step: 'initialize', + priority: 'high' + }) + } + } +} + +/** + * Middleware hook that records the middleware chain + * when command is `PublishCommand` + * + * @param {Shim} shim + * @param {function} next middleware function + * @param {Object} context + * @returns {function} + */ +function snsMiddleware(shim, next, context) { + if (context.commandName === 'PublishCommand') { + return shim.recordProduce(next, getSnsSpec) } - shim.setLibrary(shim.SNS) - shim.recordProduce(AWS.SNSClient.prototype, 'send', wrapClientSend) + return next +} + +/** + * Returns the spec for PublishCommand + * + * @param {Shim} shim + * @param {original} original original middleware function + * @param {Array} args to the middleware function + * @returns {Object} + */ +function getSnsSpec(shim, original, name, args) { + const [command] = args + return { + promise: true, + callback: shim.LAST, + destinationName: getDestinationName(command.input), + destinationType: shim.TOPIC, + opaque: true + } +} + +/** + * Helper to set the appropriate destinationName based on + * the command input + * + * @param {Object} + */ +function getDestinationName({ TopicArn, TargetArn }) { + return TopicArn || TargetArn || 'PhoneNumber' // We don't want the value of PhoneNumber } diff --git a/nr-hooks.js b/nr-hooks.js index e6d6753574..8ae672d851 100644 --- a/nr-hooks.js +++ b/nr-hooks.js @@ -14,7 +14,7 @@ module.exports = [ { type: 'message', moduleName: '@aws-sdk/client-sns', - onRequire: require('./lib/v3-sns') + onResolved: require('./lib/v3-sns') }, { type: 'generic', diff --git a/tests/versioned/aws-sdk-v3/sns.tap.js b/tests/versioned/aws-sdk-v3/sns.tap.js index 3b2157738b..812f3d150e 100644 --- a/tests/versioned/aws-sdk-v3/sns.tap.js +++ b/tests/versioned/aws-sdk-v3/sns.tap.js @@ -34,7 +34,7 @@ tap.test('SNS', (t) => { helper.registerInstrumentation({ moduleName: '@aws-sdk/client-sns', type: 'message', - onRequire: require('../../../lib/v3-sns') + onResolved: require('../../../lib/v3-sns') }) const lib = require('@aws-sdk/client-sns') const SNSClient = lib.SNSClient @@ -49,7 +49,7 @@ tap.test('SNS', (t) => { }) t.afterEach(() => { - server.close() + server.destroy() server = null // this may be brute force but i could not figure out // which files within the modules were cached preventing the instrumenting @@ -62,6 +62,22 @@ tap.test('SNS', (t) => { helper && helper.unload() }) + t.test('publish with callback', (t) => { + helper.runInTransaction((tx) => { + const params = { Message: 'Hello!' } + + const cmd = new PublishCommand(params) + sns.send(cmd, (err) => { + t.error(err) + tx.end() + + const destName = 'PhoneNumber' + const args = [t, tx, destName] + setImmediate(finish, ...args) + }) + }) + }) + t.test('publish with default destination(PhoneNumber)', (t) => { helper.runInTransaction(async (tx) => { const params = { Message: 'Hello!' } diff --git a/tests/versioned/aws-server-stubs/response-server/index.js b/tests/versioned/aws-server-stubs/response-server/index.js index 418dd0ef1b..784bdf874f 100644 --- a/tests/versioned/aws-server-stubs/response-server/index.js +++ b/tests/versioned/aws-server-stubs/response-server/index.js @@ -29,6 +29,22 @@ function createResponseServer() { res.end() }) + // server.destroy: close, but faster! + // tracks and manually closes any open sockets + const sockets = new Set() + server.on('connection', (socket) => { + sockets.add(socket) + socket.once('close', () => { + sockets.delete(socket) + }) + }) + server.destroy = function () { + sockets.forEach((socket) => { + socket.destroy() + }) + server.close() + } + return server }