Skip to content

Commit

Permalink
Feature/Add tool choices to openai assistant (FlowiseAI#2682)
Browse files Browse the repository at this point in the history
add tool choices to openai
  • Loading branch information
HenryHengZJ authored Jun 19, 2024
1 parent 8bb8416 commit 72e5287
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 50 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,6 @@
"resolutions": {
"@qdrant/openapi-typescript-fetch": "1.2.1",
"@google/generative-ai": "^0.7.0",
"openai": "4.38.3"
"openai": "4.51.0"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class OpenAIAssistant_Agents implements INode {
constructor() {
this.label = 'OpenAI Assistant'
this.name = 'openAIAssistant'
this.version = 3.0
this.version = 4.0
this.type = 'OpenAIAssistant'
this.category = 'Agents'
this.icon = 'assistant.svg'
Expand All @@ -54,6 +54,25 @@ class OpenAIAssistant_Agents implements INode {
optional: true,
list: true
},
{
label: 'Tool Choice',
name: 'toolChoice',
type: 'string',
description:
'Controls which (if any) tool is called by the model. Can be "none", "auto", "required", or the name of a tool. Refer <a href="https://platform.openai.com/docs/api-reference/runs/createRun#runs-createrun-tool_choice" target="_blank">here</a> for more information',
placeholder: 'file_search',
optional: true,
additionalParams: true
},
{
label: 'Parallel Tool Calls',
name: 'parallelToolCalls',
type: 'boolean',
description: 'Whether to enable parallel function calling during tool use. Defaults to true',
default: true,
optional: true,
additionalParams: true
},
{
label: 'Disable File Download',
name: 'disableFileDownload',
Expand Down Expand Up @@ -155,6 +174,8 @@ class OpenAIAssistant_Agents implements INode {
const databaseEntities = options.databaseEntities as IDatabaseEntity
const disableFileDownload = nodeData.inputs?.disableFileDownload as boolean
const moderations = nodeData.inputs?.inputModeration as Moderation[]
const _toolChoice = nodeData.inputs?.toolChoice as string
const parallelToolCalls = nodeData.inputs?.parallelToolCalls as boolean
const isStreaming = options.socketIO && options.socketIOClientId
const socketIO = isStreaming ? options.socketIO : undefined
const socketIOClientId = isStreaming ? options.socketIOClientId : ''
Expand Down Expand Up @@ -273,10 +294,25 @@ class OpenAIAssistant_Agents implements INode {
let runThreadId = ''
let isStreamingStarted = false

let toolChoice: any
if (_toolChoice) {
if (_toolChoice === 'file_search') {
toolChoice = { type: 'file_search' }
} else if (_toolChoice === 'code_interpreter') {
toolChoice = { type: 'code_interpreter' }
} else if (_toolChoice === 'none' || _toolChoice === 'auto' || _toolChoice === 'required') {
toolChoice = _toolChoice
} else {
toolChoice = { type: 'function', function: { name: _toolChoice } }
}
}

if (isStreaming) {
const streamThread = await openai.beta.threads.runs.create(threadId, {
assistant_id: retrievedAssistant.id,
stream: true
stream: true,
tool_choice: toolChoice,
parallel_tool_calls: parallelToolCalls
})

for await (const event of streamThread) {
Expand Down Expand Up @@ -599,7 +635,9 @@ class OpenAIAssistant_Agents implements INode {

// Polling run status
const runThread = await openai.beta.threads.runs.create(threadId, {
assistant_id: retrievedAssistant.id
assistant_id: retrievedAssistant.id,
tool_choice: toolChoice,
parallel_tool_calls: parallelToolCalls
})
runThreadId = runThread.id
let state = await promise(threadId, runThread.id)
Expand All @@ -612,7 +650,9 @@ class OpenAIAssistant_Agents implements INode {
if (retries > 0) {
retries -= 1
const newRunThread = await openai.beta.threads.runs.create(threadId, {
assistant_id: retrievedAssistant.id
assistant_id: retrievedAssistant.id,
tool_choice: toolChoice,
parallel_tool_calls: parallelToolCalls
})
runThreadId = newRunThread.id
state = await promise(threadId, newRunThread.id)
Expand Down
2 changes: 1 addition & 1 deletion packages/components/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
"node-html-markdown": "^1.3.0",
"notion-to-md": "^3.1.1",
"object-hash": "^3.0.0",
"openai": "^4.38.3",
"openai": "^4.51.0",
"pdf-parse": "^1.1.1",
"pdfjs-dist": "^3.7.107",
"pg": "^8.11.2",
Expand Down
2 changes: 1 addition & 1 deletion packages/server/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
"moment-timezone": "^0.5.34",
"multer": "^1.4.5-lts.1",
"mysql2": "^3.9.2",
"openai": "^4.20.0",
"openai": "^4.51.0",
"pg": "^8.11.1",
"posthog-node": "^3.5.0",
"reflect-metadata": "^0.1.13",
Expand Down
Loading

0 comments on commit 72e5287

Please sign in to comment.