Commit
Feature/Mistral FunctionAgent (FlowiseAI#1912)

* add mistral ai agent, add used tools streaming
* fix AWS Bedrock imports
* update pnpm lock

1 parent: 58122e9 · commit: cd4c659
13 changed files with 30,549 additions and 29,820 deletions.
1 change: 1 addition & 0 deletions
packages/components/nodes/agents/MistralAIFunctionAgent/MistralAI.svg
(new SVG icon for the node; image not shown here)
207 changes: 207 additions & 0 deletions
packages/components/nodes/agents/MistralAIFunctionAgent/MistralAIFunctionAgent.ts
@@ -0,0 +1,207 @@
import { flatten } from 'lodash'
import { BaseMessage } from '@langchain/core/messages'
import { ChainValues } from '@langchain/core/utils/types'
import { AgentStep } from '@langchain/core/agents'
import { RunnableSequence } from '@langchain/core/runnables'
import { ChatOpenAI } from '@langchain/openai'
import { convertToOpenAITool } from '@langchain/core/utils/function_calling'
import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
import { OpenAIToolsAgentOutputParser } from 'langchain/agents/openai/output_parser'
import { getBaseClasses } from '../../../src/utils'
import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, IUsedTool } from '../../../src/Interface'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { AgentExecutor, formatAgentSteps } from '../../../src/agents'
import { Moderation, checkInputs, streamResponse } from '../../moderation/Moderation'
import { formatResponse } from '../../outputparsers/OutputParserHelpers'

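// Flowise node definition for the MistralAI Function Agent. It registers the
// node in the Agents category and declares its inputs: tools, chat memory, a
// Mistral chat model, an optional system message, and optional input moderation.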
class MistralAIFunctionAgent_Agents implements INode {
    label: string
    name: string
    version: number
    description: string
    type: string
    icon: string
    category: string
    baseClasses: string[]
    inputs: INodeParams[]
    sessionId?: string
    badge?: string

    constructor(fields?: { sessionId?: string }) {
        this.label = 'MistralAI Function Agent'
        this.name = 'mistralAIFunctionAgent'
        this.version = 1.0
        this.type = 'AgentExecutor'
        this.category = 'Agents'
        this.icon = 'MistralAI.svg'
        this.badge = 'NEW'
        this.description = `An agent that uses MistralAI Function Calling to pick the tool and args to call`
        this.baseClasses = [this.type, ...getBaseClasses(AgentExecutor)]
        this.inputs = [
            {
                label: 'Tools',
                name: 'tools',
                type: 'Tool',
                list: true
            },
            {
                label: 'Memory',
                name: 'memory',
                type: 'BaseChatMemory'
            },
            {
                label: 'MistralAI Chat Model',
                name: 'model',
                type: 'BaseChatModel'
            },
            {
                label: 'System Message',
                name: 'systemMessage',
                type: 'string',
                rows: 4,
                optional: true,
                additionalParams: true
            },
            {
                label: 'Input Moderation',
                description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
                name: 'inputModeration',
                type: 'Moderation',
                optional: true,
                list: true
            }
        ]
        this.sessionId = fields?.sessionId
    }

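    // init() runs when the flow is built: it returns the prepared AgentExecutor
    // without invoking it.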
    async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
        return prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory)
    }

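    // run() handles one chat message: moderate the input, invoke the agent
    // (streaming over socket.io when a client is connected), persist the
    // exchange to memory, and shape the response.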
    async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
        const memory = nodeData.inputs?.memory as FlowiseMemory
        const moderations = nodeData.inputs?.inputModeration as Moderation[]

        if (moderations && moderations.length > 0) {
            try {
                // Use the output of the moderation chain as input for the MistralAI Function Agent
                input = await checkInputs(moderations, input)
            } catch (e) {
                await new Promise((resolve) => setTimeout(resolve, 500))
                streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId)
                return formatResponse(e.message)
            }
        }

        const executor = prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory)

        const loggerHandler = new ConsoleCallbackHandler(options.logger)
        const callbacks = await additionalCallbacks(nodeData, options)

        let res: ChainValues = {}
        let sourceDocuments: ICommonObject[] = []
        let usedTools: IUsedTool[] = []

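        // With a connected socket.io client, stream tokens via CustomChainHandler
        // and emit sourceDocuments/usedTools events; otherwise invoke synchronously.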
        if (options.socketIO && options.socketIOClientId) {
            const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
            res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] })
            if (res.sourceDocuments) {
                options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
                sourceDocuments = res.sourceDocuments
            }
            if (res.usedTools) {
                options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
                usedTools = res.usedTools
            }
        } else {
            res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
            if (res.sourceDocuments) {
                sourceDocuments = res.sourceDocuments
            }
            if (res.usedTools) {
                usedTools = res.usedTools
            }
        }

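        // Persist the exchange so follow-up questions keep conversational context.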
        await memory.addChatMessages(
            [
                {
                    text: input,
                    type: 'userMessage'
                },
                {
                    text: res?.output,
                    type: 'apiMessage'
                }
            ],
            this.sessionId
        )

        let finalRes = res?.output

        if (sourceDocuments.length || usedTools.length) {
            finalRes = { text: res?.output }
            if (sourceDocuments.length) {
                finalRes.sourceDocuments = flatten(sourceDocuments)
            }
            if (usedTools.length) {
                finalRes.usedTools = usedTools
            }
            return finalRes
        }

        return finalRes
    }
}

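// Builds the runnable agent: a prompt with memory and scratchpad placeholders,
// the Mistral model bound to tools in the OpenAI tools schema (which Mistral's
// function-calling API follows), and the OpenAI tools output parser, wrapped in
// Flowise's AgentExecutor.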
const prepareAgent = (
    nodeData: INodeData,
    flowObj: { sessionId?: string; chatId?: string; input?: string },
    chatHistory: IMessage[] = []
) => {
    const model = nodeData.inputs?.model as ChatOpenAI
    const memory = nodeData.inputs?.memory as FlowiseMemory
    const systemMessage = nodeData.inputs?.systemMessage as string
    let tools = nodeData.inputs?.tools
    tools = flatten(tools)
    const memoryKey = memory.memoryKey ? memory.memoryKey : 'chat_history'
    const inputKey = memory.inputKey ? memory.inputKey : 'input'

    const prompt = ChatPromptTemplate.fromMessages([
        ['system', systemMessage ? systemMessage : `You are a helpful AI assistant.`],
        new MessagesPlaceholder(memoryKey),
        ['human', `{${inputKey}}`],
        new MessagesPlaceholder('agent_scratchpad')
    ])

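    // Bind the tools' JSON schemas to the model so it can emit tool calls.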
    const llmWithTools = model.bind({
        tools: tools.map(convertToOpenAITool)
    })

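    // Pipeline: map the raw input, prior agent steps (scratchpad), and stored
    // chat history into the prompt variables, call the tool-bound model, then
    // parse its response into either tool invocations or a final answer.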
    const runnableAgent = RunnableSequence.from([
        {
            [inputKey]: (i: { input: string; steps: AgentStep[] }) => i.input,
            agent_scratchpad: (i: { input: string; steps: AgentStep[] }) => formatAgentSteps(i.steps),
            [memoryKey]: async (_: { input: string; steps: AgentStep[] }) => {
                const messages = (await memory.getChatMessages(flowObj?.sessionId, true, chatHistory)) as BaseMessage[]
                return messages ?? []
            }
        },
        prompt,
        llmWithTools,
        new OpenAIToolsAgentOutputParser()
    ])

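    // Flowise's AgentExecutor wrapper threads sessionId/chatId/input through to
    // the tools and collects the usedTools list surfaced in run() above.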
    const executor = AgentExecutor.fromAgentAndTools({
        agent: runnableAgent,
        tools,
        sessionId: flowObj?.sessionId,
        chatId: flowObj?.chatId,
        input: flowObj?.input,
        verbose: process.env.DEBUG === 'true' ? true : false
    })

    return executor
}

module.exports = { nodeClass: MistralAIFunctionAgent_Agents }
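
For orientation, the agent loop this node assembles can be reproduced with plain LangChain JS. The following is a minimal sketch, not part of this commit: it assumes the @langchain/mistralai package, substitutes LangChain's formatToOpenAIToolMessages and stock AgentExecutor for Flowise's formatAgentSteps and custom executor, and uses a hypothetical get_time tool.

import { ChatMistralAI } from '@langchain/mistralai'
import { DynamicTool } from 'langchain/tools'
import { AgentExecutor } from 'langchain/agents'
import { RunnableSequence } from '@langchain/core/runnables'
import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
import { convertToOpenAITool } from '@langchain/core/utils/function_calling'
import { OpenAIToolsAgentOutputParser, ToolsAgentStep } from 'langchain/agents/openai/output_parser'
import { formatToOpenAIToolMessages } from 'langchain/agents/format_scratchpad/openai_tools'

// Hypothetical example tool; any LangChain tool works here.
const tools = [
    new DynamicTool({
        name: 'get_time',
        description: 'Returns the current time as an ISO string',
        func: async () => new Date().toISOString()
    })
]

// Mistral accepts tool definitions in the OpenAI tools schema, which is why the
// node above reuses convertToOpenAITool and OpenAIToolsAgentOutputParser.
// Reads MISTRAL_API_KEY from the environment.
const model = new ChatMistralAI({ modelName: 'mistral-large-latest', temperature: 0 })
const llmWithTools = model.bind({ tools: tools.map(convertToOpenAITool) })

const prompt = ChatPromptTemplate.fromMessages([
    ['system', 'You are a helpful AI assistant.'],
    ['human', '{input}'],
    new MessagesPlaceholder('agent_scratchpad')
])

const agent = RunnableSequence.from([
    {
        input: (i: { input: string; steps: ToolsAgentStep[] }) => i.input,
        // Stand-in for Flowise's formatAgentSteps: renders prior tool calls and
        // tool results as messages for the scratchpad.
        agent_scratchpad: (i: { input: string; steps: ToolsAgentStep[] }) => formatToOpenAIToolMessages(i.steps)
    },
    prompt,
    llmWithTools,
    new OpenAIToolsAgentOutputParser()
])

async function main() {
    const executor = new AgentExecutor({ agent, tools })
    const res = await executor.invoke({ input: 'What time is it right now?' })
    console.log(res.output)
}
main()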