-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds tool calling output parser (#3232)
- Loading branch information
1 parent
0cb640f
commit 3289345
Showing
6 changed files
with
172 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import { z } from "zod"; | ||
import { zodToJsonSchema } from "zod-to-json-schema"; | ||
import { ChatPromptTemplate } from "langchain/prompts"; | ||
import { ChatOpenAI } from "langchain/chat_models/openai"; | ||
import { JsonOutputToolsParser } from "langchain/output_parsers"; | ||
|
||
const EXTRACTION_TEMPLATE = `Extract and save the relevant entities mentioned \ | ||
in the following passage together with their properties. | ||
If a property is not present and is not required in the function parameters, do not include it in the output.`; | ||
|
||
const prompt = ChatPromptTemplate.fromMessages([ | ||
["system", EXTRACTION_TEMPLATE], | ||
["human", "{input}"], | ||
]); | ||
|
||
const person = z.object({ | ||
name: z.string().describe("The person's name"), | ||
age: z.string().describe("The person's age"), | ||
}); | ||
|
||
const model = new ChatOpenAI({ | ||
modelName: "gpt-3.5-turbo-1106", | ||
temperature: 0, | ||
}).bind({ | ||
tools: [ | ||
{ | ||
type: "function", | ||
function: { | ||
name: "person", | ||
description: "A person", | ||
parameters: zodToJsonSchema(person), | ||
}, | ||
}, | ||
], | ||
}); | ||
|
||
const parser = new JsonOutputToolsParser(); | ||
const chain = prompt.pipe(model).pipe(parser); | ||
|
||
const res = await chain.invoke({ | ||
input: "jane is 2 and bob is 3", | ||
}); | ||
|
||
console.log(res); | ||
/* | ||
[ | ||
{ name: 'person', arguments: { name: 'jane', age: '2' } }, | ||
{ name: 'person', arguments: { name: 'bob', age: '3' } } | ||
] | ||
*/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import { BaseLLMOutputParser } from "../schema/output_parser.js"; | ||
import type { ChatGeneration } from "../schema/index.js"; | ||
|
||
export type ParsedToolCall = { | ||
name: string; | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
arguments: Record<string, any>; | ||
}; | ||
|
||
/** | ||
* Class for parsing the output of an LLM into a JSON object. Uses an | ||
* instance of `OutputToolsParser` to parse the output. | ||
*/ | ||
export class JsonOutputToolsParser extends BaseLLMOutputParser< | ||
ParsedToolCall[] | ||
> { | ||
static lc_name() { | ||
return "JsonOutputToolsParser"; | ||
} | ||
|
||
lc_namespace = ["langchain", "output_parsers"]; | ||
|
||
lc_serializable = true; | ||
|
||
/** | ||
* Parses the output and returns a JSON object. If `argsOnly` is true, | ||
* only the arguments of the function call are returned. | ||
* @param generations The output of the LLM to parse. | ||
* @returns A JSON object representation of the function call or its arguments. | ||
*/ | ||
async parseResult(generations: ChatGeneration[]): Promise<ParsedToolCall[]> { | ||
const toolCalls = generations[0].message.additional_kwargs.tool_calls; | ||
if (!toolCalls) { | ||
throw new Error( | ||
`No tools_call in message ${JSON.stringify(generations)}` | ||
); | ||
} | ||
const clonedToolCalls = JSON.parse(JSON.stringify(toolCalls)); | ||
const parsedToolCalls = []; | ||
for (const toolCall of clonedToolCalls) { | ||
if (toolCall.function !== undefined) { | ||
const functionArgs = toolCall.function.arguments; | ||
parsedToolCalls.push({ | ||
name: toolCall.function.name, | ||
arguments: JSON.parse(functionArgs), | ||
}); | ||
} | ||
} | ||
return parsedToolCalls; | ||
} | ||
} |
45 changes: 45 additions & 0 deletions
45
langchain/src/output_parsers/tests/openai_tools.int.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/* eslint-disable @typescript-eslint/no-explicit-any */ | ||
|
||
import { expect, test } from "@jest/globals"; | ||
import { z } from "zod"; | ||
import { zodToJsonSchema } from "zod-to-json-schema"; | ||
|
||
import { ChatOpenAI } from "../../chat_models/openai.js"; | ||
import { ChatPromptTemplate } from "../../prompts/index.js"; | ||
import { JsonOutputToolsParser } from "../openai_tools.js"; | ||
|
||
const schema = z.object({ | ||
setup: z.string().describe("The setup for the joke"), | ||
punchline: z.string().describe("The punchline to the joke"), | ||
}); | ||
|
||
test("Extraction", async () => { | ||
const prompt = ChatPromptTemplate.fromTemplate( | ||
`tell me two jokes about {foo}` | ||
); | ||
const model = new ChatOpenAI({ | ||
modelName: "gpt-3.5-turbo-1106", | ||
temperature: 0, | ||
}).bind({ | ||
tools: [ | ||
{ | ||
type: "function", | ||
function: { | ||
name: "joke", | ||
description: "A joke", | ||
parameters: zodToJsonSchema(schema), | ||
}, | ||
}, | ||
], | ||
}); | ||
|
||
const parser = new JsonOutputToolsParser(); | ||
const chain = prompt.pipe(model).pipe(parser); | ||
|
||
const res = await chain.invoke({ | ||
foo: "bears", | ||
}); | ||
|
||
console.log(res); | ||
expect(res.length).toBe(2); | ||
}); |