feat: Added setLlmTokenCountCallback API endpoint to register a callback for calculating token count when none is provided (newrelic#2065)
bizob2828 authored Mar 6, 2024
1 parent 793abe8 commit d2faf1a
Showing 11 changed files with 322 additions and 11 deletions.
35 changes: 34 additions & 1 deletion api.js
@@ -1839,7 +1839,7 @@ API.prototype.setErrorGroupCallback = function setErrorGroupCallback(callback) {
)
metric.incrementCallCount()

if (!this.shim.isFunction(callback) || this.shim.isPromise(callback)) {
if (!this.shim.isFunction(callback) || this.shim.isAsyncFunction(callback)) {
logger.warn(
'Error Group callback must be a synchronous function, Error Group attribute will not be added'
)
@@ -1849,4 +1849,37 @@ API.prototype.setErrorGroupCallback = function setErrorGroupCallback(callback) {
this.agent.errors.errorGroupCallback = callback
}

/**
* Registers a callback which will be used for calculating token counts on Llm events when they are not
* available. This function will typically only be used if `ai_monitoring.record_content.enabled` is false
* and you still want to capture token counts for Llm events.
*
* Provided callbacks must return an integer value for the token count for a given piece of content.
*
* @param {Function} callback - synchronous function called to calculate token count for content.
* @example
* // @param {string} model - name of model (e.g. gpt-3.5-turbo)
* // @param {string} content - prompt or completion response
* function tokenCallback(model, content) {
* // calculate tokens based on model and content
* // return token count
* return 40
* }
*/
API.prototype.setLlmTokenCountCallback = function setLlmTokenCountCallback(callback) {
const metric = this.agent.metrics.getOrCreateMetric(
NAMES.SUPPORTABILITY.API + '/setLlmTokenCountCallback'
)
metric.incrementCallCount()

if (!this.shim.isFunction(callback) || this.shim.isAsyncFunction(callback)) {
logger.warn(
'Llm token count callback must be a synchronous function, callback will not be registered.'
)
return
}

this.agent.llm.tokenCountCallback = callback
}

module.exports = API
2 changes: 1 addition & 1 deletion lib/instrumentation/restify.js
@@ -93,7 +93,7 @@ function wrapMiddleware(shim, middleware, _name, route) {
})

const wrappedMw = shim.recordMiddleware(middleware, spec)
if (middleware.constructor.name === 'AsyncFunction') {
if (shim.isAsyncFunction(middleware)) {
return async function asyncShim() {
return wrappedMw.apply(this, arguments)
}
8 changes: 6 additions & 2 deletions lib/llm-events/openai/chat-completion-message.js
@@ -20,9 +20,13 @@ module.exports = class LlmChatCompletionMessage extends LlmEvent {
}

if (this.is_response) {
this.token_count = response?.usage?.completion_tokens
this.token_count =
response?.usage?.completion_tokens ||
agent.llm?.tokenCountCallback?.(this['response.model'], message?.content)
} else {
this.token_count = response?.usage?.prompt_tokens
this.token_count =
response?.usage?.prompt_tokens ||
agent.llm?.tokenCountCallback?.(request.model || request.engine, message?.content)
}
}
}
4 changes: 3 additions & 1 deletion lib/llm-events/openai/embedding.js
@@ -14,6 +14,8 @@ module.exports = class LlmEmbedding extends LlmEvent {
if (agent.config.ai_monitoring.record_content.enabled === true) {
this.input = request.input?.toString()
}
this.token_count = response?.usage?.prompt_tokens
this.token_count =
response?.usage?.prompt_tokens ||
agent.llm?.tokenCountCallback?.(this['request.model'], request.input?.toString())
}
}
15 changes: 15 additions & 0 deletions lib/shim/shim.js
@@ -125,6 +125,7 @@ Shim.prototype.getName = getName
Shim.prototype.isObject = isObject
Shim.prototype.isFunction = isFunction
Shim.prototype.isPromise = isPromise
Shim.prototype.isAsyncFunction = isAsyncFunction
Shim.prototype.isString = isString
Shim.prototype.isNumber = isNumber
Shim.prototype.isBoolean = isBoolean
@@ -1345,6 +1346,20 @@ function isPromise(obj) {
return obj && typeof obj.then === 'function'
}

/**
* Determines if a function is an async function.
* Note: it does not test whether the function's return value is a
* promise; only functions declared with `async` are detected.
*
* @memberof Shim.prototype
* @param {Function} fn - function to test
* @returns {boolean} True if the function is an async function
*/
function isAsyncFunction(fn) {
return fn.constructor.name === 'AsyncFunction'
}

/**
* Determines if the given value is null.
*
48 changes: 47 additions & 1 deletion test/unit/api/api-llm.test.js
@@ -28,7 +28,8 @@ tap.test('Agent API LLM methods', (t) => {
loggerMock.warn.reset()
const agent = helper.loadMockedAgent()
t.context.api = new API(agent)
t.context.api.agent.config.ai_monitoring.enabled = true
agent.config.ai_monitoring.enabled = true
t.context.agent = agent
})

t.afterEach((t) => {
@@ -119,4 +120,49 @@
})
})
})

t.test('setLlmTokenCountCallback should register callback to calculate token counts', async (t) => {
const { api, agent } = t.context
function callback(model, content) {
if (model === 'foo' && content === 'bar') {
return 10
}

return 1
}
api.setLlmTokenCountCallback(callback)
t.same(agent.llm.tokenCountCallback, callback)
})

t.test('should not store token count callback if it is async', async (t) => {
const { api, agent } = t.context
async function callback(model, content) {
return await new Promise((resolve) => {
if (model === 'foo' && content === 'bar') {
resolve(10)
}
})
}
api.setLlmTokenCountCallback(callback)
t.same(agent.llm.tokenCountCallback, undefined)
t.equal(loggerMock.warn.callCount, 1)
t.equal(
loggerMock.warn.args[0][0],
'Llm token count callback must be a synchronous function, callback will not be registered.'
)
})

t.test(
'should not store token count callback if callback is not actually a function',
async (t) => {
const { api, agent } = t.context
api.setLlmTokenCountCallback({ unit: 'test' })
t.same(agent.llm.tokenCountCallback, undefined)
t.equal(loggerMock.warn.callCount, 1)
t.equal(
loggerMock.warn.args[0][0],
'Llm token count callback must be a synchronous function, callback will not be registered.'
)
}
)
})
4 changes: 2 additions & 2 deletions test/unit/api/api-set-error-group-callback.test.js
@@ -84,8 +84,8 @@ tap.test('Agent API = set Error Group callback', (t) => {
})

t.test('should not attach the callback when async function', (t) => {
function callback() {
return new Promise((resolve) => {
async function callback() {
return await new Promise((resolve) => {
setTimeout(() => {
resolve()
}, 200)
2 changes: 1 addition & 1 deletion test/unit/api/stub.test.js
@@ -8,7 +8,7 @@
const tap = require('tap')
const API = require('../../../stub_api')

const EXPECTED_API_COUNT = 34
const EXPECTED_API_COUNT = 35

tap.test('Agent API - Stubbed Agent API', (t) => {
t.autoend()
113 changes: 111 additions & 2 deletions test/unit/llm-events/openai/chat-completion-message.test.js
@@ -11,8 +11,6 @@ const helper = require('../../../lib/agent_helper')
const { req, chatRes, getExpectedResult } = require('./common')

tap.test('LlmChatCompletionMessage', (t) => {
t.autoend()

let agent
t.beforeEach(() => {
agent = helper.loadMockedAgent()
@@ -104,4 +102,115 @@
t.end()
})
})

t.test('should use token_count from tokenCountCallback for prompt message', (t) => {
const api = helper.getAgentApi()
const expectedCount = 4
function cb(model, content) {
t.equal(model, 'gpt-3.5-turbo-0613')
t.equal(content, 'What is a woodchuck?')
return expectedCount
}
api.setLlmTokenCountCallback(cb)
helper.runInTransaction(agent, () => {
api.startSegment('fakeSegment', false, () => {
const segment = api.shim.getActiveSegment()
const summaryId = 'chat-summary-id'
delete chatRes.usage
const chatMessageEvent = new LlmChatCompletionMessage({
agent,
segment,
request: req,
response: chatRes,
completionId: summaryId,
message: req.messages[0],
index: 0
})
t.equal(chatMessageEvent.token_count, expectedCount)
t.end()
})
})
})

t.test('should use token_count from tokenCountCallback for completion messages', (t) => {
const api = helper.getAgentApi()
const expectedCount = 4
function cb(model, content) {
t.equal(model, 'gpt-3.5-turbo-0613')
t.equal(content, 'a lot')
return expectedCount
}
api.setLlmTokenCountCallback(cb)
helper.runInTransaction(agent, () => {
api.startSegment('fakeSegment', false, () => {
const segment = api.shim.getActiveSegment()
const summaryId = 'chat-summary-id'
delete chatRes.usage
const chatMessageEvent = new LlmChatCompletionMessage({
agent,
segment,
request: req,
response: chatRes,
completionId: summaryId,
message: chatRes.choices[0].message,
index: 2
})
t.equal(chatMessageEvent.token_count, expectedCount)
t.end()
})
})
})

t.test('should not set token_count if not set in usage nor a callback registered', (t) => {
const api = helper.getAgentApi()
helper.runInTransaction(agent, () => {
api.startSegment('fakeSegment', false, () => {
const segment = api.shim.getActiveSegment()
const summaryId = 'chat-summary-id'
delete chatRes.usage
const chatMessageEvent = new LlmChatCompletionMessage({
agent,
segment,
request: req,
response: chatRes,
completionId: summaryId,
message: chatRes.choices[0].message,
index: 2
})
t.equal(chatMessageEvent.token_count, undefined)
t.end()
})
})
})

t.test(
'should not set token_count if not set in usage and registered callback returns no count',
(t) => {
const api = helper.getAgentApi()
function cb() {
// empty cb
}
api.setLlmTokenCountCallback(cb)
helper.runInTransaction(agent, () => {
api.startSegment('fakeSegment', false, () => {
const segment = api.shim.getActiveSegment()
const summaryId = 'chat-summary-id'
delete chatRes.usage
const chatMessageEvent = new LlmChatCompletionMessage({
agent,
segment,
request: req,
response: chatRes,
completionId: summaryId,
message: chatRes.choices[0].message,
index: 2
})
t.equal(chatMessageEvent.token_count, undefined)
t.end()
})
})
}
)

t.end()
})
50 changes: 50 additions & 0 deletions test/unit/llm-events/openai/embedding.test.js
@@ -113,4 +113,54 @@
t.end()
})
})

t.test('should calculate token count from tokenCountCallback', (t) => {
const req = {
input: 'This is my test input',
model: 'gpt-3.5-turbo-0613'
}

const api = helper.getAgentApi()

function cb(model, content) {
if (model === req.model) {
return content.length
}
}

api.setLlmTokenCountCallback(cb)
helper.runInTransaction(agent, () => {
const segment = api.shim.getActiveSegment()
delete res.usage
const embeddingEvent = new LlmEmbedding({
agent,
segment,
request: req,
response: res
})
t.equal(embeddingEvent.token_count, 21)
t.end()
})
})

t.test('should not set token count when not present in usage nor tokenCountCallback', (t) => {
const req = {
input: 'This is my test input',
model: 'gpt-3.5-turbo-0613'
}

const api = helper.getAgentApi()
helper.runInTransaction(agent, () => {
const segment = api.shim.getActiveSegment()
delete res.usage
const embeddingEvent = new LlmEmbedding({
agent,
segment,
request: req,
response: res
})
t.equal(embeddingEvent.token_count, undefined)
t.end()
})
})
})
