diff --git a/api.js b/api.js
index d37eed31d2..719ebc69f6 100644
--- a/api.js
+++ b/api.js
@@ -1839,7 +1839,7 @@ API.prototype.setErrorGroupCallback = function setErrorGroupCallback(callback) {
   )
   metric.incrementCallCount()
 
-  if (!this.shim.isFunction(callback) || this.shim.isPromise(callback)) {
+  if (!this.shim.isFunction(callback) || this.shim.isAsyncFunction(callback)) {
     logger.warn(
       'Error Group callback must be a synchronous function, Error Group attribute will not be added'
     )
@@ -1849,4 +1849,37 @@ API.prototype.setErrorGroupCallback = function setErrorGroupCallback(callback) {
   this.agent.errors.errorGroupCallback = callback
 }
 
+/**
+ * Registers a callback which will be used for calculating token counts on Llm events when they are not
+ * available. This function will typically only be used if `ai_monitoring.record_content.enabled` is false
+ * and you still want to capture token counts for Llm events.
+ *
+ * Provided callbacks must return an integer value for the token count for a given piece of content.
+ *
+ * @param {Function} callback - synchronous function called to calculate token count for content.
+ * @example
+ * // @param {string} model - name of model (e.g. gpt-3.5-turbo)
+ * // @param {string} content - prompt or completion response
+ * function tokenCallback(model, content) {
+ *   // calculate tokens based on model and content
+ *   // return token count
+ *   return 40
+ * }
+ */
+API.prototype.setLlmTokenCountCallback = function setLlmTokenCountCallback(callback) {
+  const metric = this.agent.metrics.getOrCreateMetric(
+    NAMES.SUPPORTABILITY.API + '/setLlmTokenCountCallback'
+  )
+  metric.incrementCallCount()
+
+  if (!this.shim.isFunction(callback) || this.shim.isAsyncFunction(callback)) {
+    logger.warn(
+      'Llm token count callback must be a synchronous function, callback will not be registered.'
+    )
+    return
+  }
+
+  this.agent.llm.tokenCountCallback = callback
+}
+
 module.exports = API
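Reviewer note: a minimal sketch of how an application would consume the new API once this lands. Illustrative only; the characters-per-token heuristic is a stand-in for whatever real tokenizer a user would plug in, and only the `setLlmTokenCountCallback(model, content)` contract comes from the patch.

```js
'use strict'
const newrelic = require('newrelic')

// Register a synchronous token counter. The agent calls it as (model, content)
// and expects an integer back; async callbacks are rejected by the guard above.
newrelic.setLlmTokenCountCallback(function tokenCounter(model, content) {
  // Hypothetical heuristic: roughly 4 characters per token for GPT-style models.
  return Math.ceil((content || '').length / 4)
})
```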
+    )
+    return
+  }
+
+  this.agent.llm.tokenCountCallback = callback
+}
+
 module.exports = API
diff --git a/lib/instrumentation/restify.js b/lib/instrumentation/restify.js
index 81240c9dff..85960bf157 100644
--- a/lib/instrumentation/restify.js
+++ b/lib/instrumentation/restify.js
@@ -93,7 +93,7 @@ function wrapMiddleware(shim, middleware, _name, route) {
   })
 
   const wrappedMw = shim.recordMiddleware(middleware, spec)
-  if (middleware.constructor.name === 'AsyncFunction') {
+  if (shim.isAsyncFunction(middleware)) {
     return async function asyncShim() {
       return wrappedMw.apply(this, arguments)
     }
diff --git a/lib/llm-events/openai/chat-completion-message.js b/lib/llm-events/openai/chat-completion-message.js
index d45a222d07..ebfbe57a7e 100644
--- a/lib/llm-events/openai/chat-completion-message.js
+++ b/lib/llm-events/openai/chat-completion-message.js
@@ -20,9 +20,13 @@ module.exports = class LlmChatCompletionMessage extends LlmEvent {
     }
 
     if (this.is_response) {
-      this.token_count = response?.usage?.completion_tokens
+      this.token_count =
+        response?.usage?.completion_tokens ||
+        agent.llm?.tokenCountCallback?.(this['response.model'], message?.content)
     } else {
-      this.token_count = response?.usage?.prompt_tokens
+      this.token_count =
+        response?.usage?.prompt_tokens ||
+        agent.llm?.tokenCountCallback?.(request.model || request.engine, message?.content)
     }
   }
 }
diff --git a/lib/llm-events/openai/embedding.js b/lib/llm-events/openai/embedding.js
index 0aaa5e65fb..74a60e855e 100644
--- a/lib/llm-events/openai/embedding.js
+++ b/lib/llm-events/openai/embedding.js
@@ -14,6 +14,8 @@ module.exports = class LlmEmbedding extends LlmEvent {
     if (agent.config.ai_monitoring.record_content.enabled === true) {
       this.input = request.input?.toString()
     }
-    this.token_count = response?.usage?.prompt_tokens
+    this.token_count =
+      response?.usage?.prompt_tokens ||
+      agent.llm?.tokenCountCallback?.(this['request.model'], request.input?.toString())
   }
 }
diff --git a/lib/shim/shim.js b/lib/shim/shim.js
index 7d09ae6c72..f3c8cf0d9f 100644
--- a/lib/shim/shim.js
+++ b/lib/shim/shim.js
@@ -125,6 +125,7 @@ Shim.prototype.getName = getName
 Shim.prototype.isObject = isObject
 Shim.prototype.isFunction = isFunction
 Shim.prototype.isPromise = isPromise
+Shim.prototype.isAsyncFunction = isAsyncFunction
 Shim.prototype.isString = isString
 Shim.prototype.isNumber = isNumber
 Shim.prototype.isBoolean = isBoolean
@@ -1345,6 +1346,20 @@ function isPromise(obj) {
   return obj && typeof obj.then === 'function'
 }
 
+/**
+ * Determines if a function is an async function.
+ * Note it does not test whether the return value of the function is a
+ * promise or an async function.
+ *
+ * @memberof Shim.prototype
+ * @param {Function} fn
+ *   function to test if async
+ * @returns {boolean} True if the function is an async function
+ */
+function isAsyncFunction(fn) {
+  return fn.constructor.name === 'AsyncFunction'
+}
+
 /**
  * Determines if the given value is null.
 *
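A note on the new `isAsyncFunction` helper: because it only inspects `fn.constructor.name`, it flags functions declared with `async` and nothing else. A plain function that happens to return a promise is not detected, which is also why the old `isPromise(callback)` guard in `setErrorGroupCallback` never caught async callbacks: it tested the callback object itself for a `.then`, not its return value. A small illustration, not part of the patch:

```js
// Same check the shim helper performs, shown standalone for clarity.
const isAsyncFunction = (fn) => fn.constructor.name === 'AsyncFunction'

console.log(isAsyncFunction(async () => {}))            // true: declared with `async`
console.log(isAsyncFunction(() => Promise.resolve(1)))  // false: returns a promise but is not declared async
console.log(isAsyncFunction(function plain() {}))       // false
```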
diff --git a/test/unit/api/api-llm.test.js b/test/unit/api/api-llm.test.js
index 542a4f9087..a2fb477d7c 100644
--- a/test/unit/api/api-llm.test.js
+++ b/test/unit/api/api-llm.test.js
@@ -28,7 +28,8 @@ tap.test('Agent API LLM methods', (t) => {
     loggerMock.warn.reset()
     const agent = helper.loadMockedAgent()
     t.context.api = new API(agent)
-    t.context.api.agent.config.ai_monitoring.enabled = true
+    agent.config.ai_monitoring.enabled = true
+    t.context.agent = agent
   })
 
   t.afterEach((t) => {
@@ -119,4 +120,49 @@ tap.test('Agent API LLM methods', (t) => {
       })
     })
   })
+
+  t.test('setLlmTokenCountCallback should register callback to calculate token counts', async (t) => {
+    const { api, agent } = t.context
+    function callback(model, content) {
+      if (model === 'foo' && content === 'bar') {
+        return 10
+      }
+
+      return 1
+    }
+    api.setLlmTokenCountCallback(callback)
+    t.same(agent.llm.tokenCountCallback, callback)
+  })
+
+  t.test('should not store token count callback if it is async', async (t) => {
+    const { api, agent } = t.context
+    async function callback(model, content) {
+      return await new Promise((resolve) => {
+        if (model === 'foo' && content === 'bar') {
+          resolve(10)
+        }
+      })
+    }
+    api.setLlmTokenCountCallback(callback)
+    t.same(agent.llm.tokenCountCallback, undefined)
+    t.equal(loggerMock.warn.callCount, 1)
+    t.equal(
+      loggerMock.warn.args[0][0],
+      'Llm token count callback must be a synchronous function, callback will not be registered.'
+    )
+  })
+
+  t.test(
+    'should not store token count callback if callback is not actually a function',
+    async (t) => {
+      const { api, agent } = t.context
+      api.setLlmTokenCountCallback({ unit: 'test' })
+      t.same(agent.llm.tokenCountCallback, undefined)
+      t.equal(loggerMock.warn.callCount, 1)
+      t.equal(
+        loggerMock.warn.args[0][0],
+        'Llm token count callback must be a synchronous function, callback will not be registered.'
+      )
+    }
+  )
 })
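Reviewer note: these unit tests read and write `agent.llm.tokenCountCallback`, so they assume the agent already exposes an `llm` namespace object; that initialization is not part of this diff. Presumably it amounts to something like the following hypothetical sketch:

```js
// Hypothetical: wherever the Agent is constructed, an empty namespace is assumed
// to exist so API.setLlmTokenCountCallback has somewhere to write.
class Agent {
  constructor() {
    this.llm = {}
  }
}
```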
diff --git a/test/unit/api/api-set-error-group-callback.test.js b/test/unit/api/api-set-error-group-callback.test.js
index 2a31f18008..defdf8e6df 100644
--- a/test/unit/api/api-set-error-group-callback.test.js
+++ b/test/unit/api/api-set-error-group-callback.test.js
@@ -84,8 +84,8 @@ tap.test('Agent API = set Error Group callback', (t) => {
   })
 
   t.test('should not attach the callback when async function', (t) => {
-    function callback() {
-      return new Promise((resolve) => {
+    async function callback() {
+      return await new Promise((resolve) => {
         setTimeout(() => {
           resolve()
         }, 200)
diff --git a/test/unit/api/stub.test.js b/test/unit/api/stub.test.js
index 858c5fd243..980b24a307 100644
--- a/test/unit/api/stub.test.js
+++ b/test/unit/api/stub.test.js
@@ -8,7 +8,7 @@
 const tap = require('tap')
 
 const API = require('../../../stub_api')
-const EXPECTED_API_COUNT = 34
+const EXPECTED_API_COUNT = 35
 
 tap.test('Agent API - Stubbed Agent API', (t) => {
   t.autoend()
diff --git a/test/unit/llm-events/openai/chat-completion-message.test.js b/test/unit/llm-events/openai/chat-completion-message.test.js
index dc80b6305c..a42a725287 100644
--- a/test/unit/llm-events/openai/chat-completion-message.test.js
+++ b/test/unit/llm-events/openai/chat-completion-message.test.js
@@ -11,8 +11,6 @@ const helper = require('../../../lib/agent_helper')
 const { req, chatRes, getExpectedResult } = require('./common')
 
 tap.test('LlmChatCompletionMessage', (t) => {
-  t.autoend()
-
   let agent
   t.beforeEach(() => {
     agent = helper.loadMockedAgent()
@@ -104,4 +102,115 @@ tap.test('LlmChatCompletionMessage', (t) => {
       t.end()
     })
   })
+
+  t.test('should use token_count from tokenCountCallback for prompt message', (t) => {
+    const api = helper.getAgentApi()
+    const expectedCount = 4
+    function cb(model, content) {
+      t.equal(model, 'gpt-3.5-turbo-0613')
+      t.equal(content, 'What is a woodchuck?')
+      return expectedCount
+    }
+    api.setLlmTokenCountCallback(cb)
+    helper.runInTransaction(agent, () => {
+      api.startSegment('fakeSegment', false, () => {
+        const segment = api.shim.getActiveSegment()
+        const summaryId = 'chat-summary-id'
+        delete chatRes.usage
+        const chatMessageEvent = new LlmChatCompletionMessage({
+          agent,
+          segment,
+          request: req,
+          response: chatRes,
+          completionId: summaryId,
+          message: req.messages[0],
+          index: 0
+        })
+        t.equal(chatMessageEvent.token_count, expectedCount)
+        t.end()
+      })
+    })
+  })
+
+  t.test('should use token_count from tokenCountCallback for completion messages', (t) => {
+    const api = helper.getAgentApi()
+    const expectedCount = 4
+    function cb(model, content) {
+      t.equal(model, 'gpt-3.5-turbo-0613')
+      t.equal(content, 'a lot')
+      return expectedCount
+    }
+    api.setLlmTokenCountCallback(cb)
+    helper.runInTransaction(agent, () => {
+      api.startSegment('fakeSegment', false, () => {
+        const segment = api.shim.getActiveSegment()
+        const summaryId = 'chat-summary-id'
+        delete chatRes.usage
+        const chatMessageEvent = new LlmChatCompletionMessage({
+          agent,
+          segment,
+          request: req,
+          response: chatRes,
+          completionId: summaryId,
+          message: chatRes.choices[0].message,
+          index: 2
+        })
+        t.equal(chatMessageEvent.token_count, expectedCount)
+        t.end()
+      })
+    })
+  })
+
+  t.test('should not set token_count if not set in usage and no callback is registered', (t) => {
+    const api = helper.getAgentApi()
+    helper.runInTransaction(agent, () => {
+      api.startSegment('fakeSegment', false, () => {
+        const segment = api.shim.getActiveSegment()
+        const summaryId = 'chat-summary-id'
+        delete chatRes.usage
+        const chatMessageEvent = new LlmChatCompletionMessage({
+          agent,
+          segment,
+          request: req,
+          response: chatRes,
+          completionId: summaryId,
+          message: chatRes.choices[0].message,
+          index: 2
+        })
+        t.equal(chatMessageEvent.token_count, undefined)
+        t.end()
+      })
+    })
+  })
+
+  t.test(
+    'should not set token_count if not set in usage nor returned by a registered callback',
+    (t) => {
+      const api = helper.getAgentApi()
+      function cb() {
+        // empty cb
+      }
+      api.setLlmTokenCountCallback(cb)
+      helper.runInTransaction(agent, () => {
+        api.startSegment('fakeSegment', false, () => {
+          const segment = api.shim.getActiveSegment()
+          const summaryId = 'chat-summary-id'
+          delete chatRes.usage
+          const chatMessageEvent = new LlmChatCompletionMessage({
+            agent,
+            segment,
+            request: req,
+            response: chatRes,
+            completionId: summaryId,
+            message: chatRes.choices[0].message,
+            index: 2
+          })
+          t.equal(chatMessageEvent.token_count, undefined)
+          t.end()
+        })
+      })
+    }
+  )
+
+  t.end()
 })
diff --git a/test/unit/llm-events/openai/embedding.test.js b/test/unit/llm-events/openai/embedding.test.js
index bbe7ab7b36..7175b072ff 100644
--- a/test/unit/llm-events/openai/embedding.test.js
+++ b/test/unit/llm-events/openai/embedding.test.js
@@ -113,4 +113,54 @@ tap.test('LlmEmbedding', (t) => {
       t.end()
     })
   })
+
+  t.test('should calculate token count from tokenCountCallback', (t) => {
+    const req = {
+      input: 'This is my test input',
+      model: 'gpt-3.5-turbo-0613'
+    }
+
+    const api = helper.getAgentApi()
+
+    function cb(model, content) {
+      if (model === req.model) {
+        return content.length
+      }
+    }
+
+    api.setLlmTokenCountCallback(cb)
+    helper.runInTransaction(agent, () => {
+      const segment = api.shim.getActiveSegment()
+      delete res.usage
+      const embeddingEvent = new LlmEmbedding({
+        agent,
+        segment,
+        request: req,
+        response: res
+      })
+      t.equal(embeddingEvent.token_count, 21)
+      t.end()
+    })
+  })
+
+  t.test('should not set token count when not present in usage nor tokenCountCallback', (t) => {
+    const req = {
+      input: 'This is my test input',
+      model: 'gpt-3.5-turbo-0613'
+    }
+
+    const api = helper.getAgentApi()
+    helper.runInTransaction(agent, () => {
+      const segment = api.shim.getActiveSegment()
+      delete res.usage
+      const embeddingEvent = new LlmEmbedding({
+        agent,
+        segment,
+        request: req,
+        response: res
+      })
+      t.equal(embeddingEvent.token_count, undefined)
+      t.end()
+    })
+  })
 })
diff --git a/test/versioned/openai/chat-completions.tap.js b/test/versioned/openai/chat-completions.tap.js
index 71493529db..efdbc63790 100644
--- a/test/versioned/openai/chat-completions.tap.js
+++ b/test/versioned/openai/chat-completions.tap.js
@@ -190,6 +190,58 @@ tap.test('OpenAI instrumentation - chat completions', (t) => {
     }
   )
 
+  t.test('should call the tokenCountCallback in streaming', (test) => {
+    const { client, agent } = t.context
+    const promptContent = 'Streamed response'
+    const promptContent2 = 'What does 1 plus 1 equal?'
+    let res = ''
+    const expectedModel = 'gpt-4'
+    const api = helper.getAgentApi()
+    function cb(model, content) {
+      t.equal(model, expectedModel)
+      if (content === promptContent || content === promptContent2) {
+        return 53
+      } else if (content === res) {
+        return 11
+      }
+    }
+    api.setLlmTokenCountCallback(cb)
+    test.teardown(() => {
+      delete agent.llm.tokenCountCallback
+    })
+    helper.runInTransaction(agent, async (tx) => {
+      const stream = await client.chat.completions.create({
+        max_tokens: 100,
+        temperature: 0.5,
+        model: expectedModel,
+        messages: [
+          { role: 'user', content: promptContent },
+          { role: 'user', content: promptContent2 }
+        ],
+        stream: true
+      })
+
+      for await (const chunk of stream) {
+        res += chunk.choices[0]?.delta?.content
+      }
+
+      const events = agent.customEventAggregator.events.toArray()
+      const chatMsgs = events.filter(([{ type }]) => type === 'LlmChatCompletionMessage')
+      test.llmMessages({
+        tokenUsage: true,
+        tx,
+        chatMsgs,
+        id: 'chatcmpl-8MzOfSMbLxEy70lYAolSwdCzfguQZ',
+        model: expectedModel,
+        resContent: res,
+        reqContent: promptContent
+      })
+
+      tx.end()
+      test.end()
+    })
+  })
+
   t.test('handles error in stream', (test) => {
     const { client, agent } = t.context
     helper.runInTransaction(agent, async (tx) => {