Feature/Mistral FunctionAgent (FlowiseAI#1912)
* add mistral ai agent, add used tools streaming

* fix AWS Bedrock imports

* update pnpm lock
HenryHengZJ authored Mar 18, 2024
1 parent 58122e9 commit cd4c659
Showing 13 changed files with 30,549 additions and 29,820 deletions.
@@ -9,7 +9,7 @@ import { RunnableSequence } from '@langchain/core/runnables'
 import { ChatConversationalAgent } from 'langchain/agents'
 import { getBaseClasses } from '../../../src/utils'
 import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
-import { IVisionChatModal, FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
+import { IVisionChatModal, FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, IUsedTool } from '../../../src/Interface'
 import { AgentExecutor } from '../../../src/agents'
 import { addImagesToMessages, llmSupportsVision } from '../../../src/multiModalUtils'
 import { checkInputs, Moderation } from '../../moderation/Moderation'
@@ -120,12 +120,28 @@ class ConversationalAgent_Agents implements INode {
         const callbacks = await additionalCallbacks(nodeData, options)

         let res: ChainValues = {}
+        let sourceDocuments: ICommonObject[] = []
+        let usedTools: IUsedTool[] = []

         if (options.socketIO && options.socketIOClientId) {
             const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
             res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] })
+            if (res.sourceDocuments) {
+                options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
+                sourceDocuments = res.sourceDocuments
+            }
+            if (res.usedTools) {
+                options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
+                usedTools = res.usedTools
+            }
         } else {
             res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
+            if (res.sourceDocuments) {
+                sourceDocuments = res.sourceDocuments
+            }
+            if (res.usedTools) {
+                usedTools = res.usedTools
+            }
         }

         await memory.addChatMessages(
@@ -142,7 +158,20 @@
             this.sessionId
         )

-        return res?.output
+        let finalRes = res?.output
+
+        if (sourceDocuments.length || usedTools.length) {
+            finalRes = { text: res?.output }
+            if (sourceDocuments.length) {
+                finalRes.sourceDocuments = flatten(sourceDocuments)
+            }
+            if (usedTools.length) {
+                finalRes.usedTools = usedTools
+            }
+            return finalRes
+        }
+
+        return finalRes
     }
 }
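
The two emit calls above define a small client-side contract. As a minimal sketch, this is how a browser client might consume the new events over socket.io; the event names ('sourceDocuments', 'usedTools') are taken from the code above, while the server URL and the exact fields of each used-tool entry are assumptions for illustration only:

import { io } from 'socket.io-client'

// Hypothetical Flowise server URL, adjust to your deployment.
const socket = io('http://localhost:3000')

socket.on('connect', () => {
    // The prediction request must carry this socket.id as socketIOClientId
    // so the server can route the events back to this client.
    console.log('connected as', socket.id)
})

// Fired when the agent's executor returns usedTools (see diff above).
socket.on('usedTools', (tools) => {
    // Assumed IUsedTool shape: { tool, toolInput, toolOutput }; verify against src/Interface.
    for (const t of tools) console.log('tool used:', t)
})

// Fired when the executor returns sourceDocuments (already flattened server-side).
socket.on('sourceDocuments', (docs) => {
    console.log('source documents:', docs)
})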

@@ -25,6 +25,7 @@ class ConversationalRetrievalAgent_Agents implements INode {
     category: string
     baseClasses: string[]
     inputs: INodeParams[]
+    badge?: string
     sessionId?: string

     constructor(fields?: { sessionId?: string }) {
@@ -33,6 +34,7 @@
         this.version = 4.0
         this.type = 'AgentExecutor'
         this.category = 'Agents'
+        this.badge = 'DEPRECATING'
         this.icon = 'agent.svg'
         this.description = `An agent optimized for retrieval during conversation, answering questions based on past dialogue, all using OpenAI's Function Calling`
         this.baseClasses = [this.type, ...getBaseClasses(AgentExecutor)]
(File preview unavailable; likely the new MistralAI.svg icon referenced by the node below.)
@@ -0,0 +1,207 @@ (new file)
import { flatten } from 'lodash'
import { BaseMessage } from '@langchain/core/messages'
import { ChainValues } from '@langchain/core/utils/types'
import { AgentStep } from '@langchain/core/agents'
import { RunnableSequence } from '@langchain/core/runnables'
import { ChatOpenAI } from '@langchain/openai'
import { convertToOpenAITool } from '@langchain/core/utils/function_calling'
import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
import { OpenAIToolsAgentOutputParser } from 'langchain/agents/openai/output_parser'
import { getBaseClasses } from '../../../src/utils'
import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, IUsedTool } from '../../../src/Interface'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { AgentExecutor, formatAgentSteps } from '../../../src/agents'
import { Moderation, checkInputs, streamResponse } from '../../moderation/Moderation'
import { formatResponse } from '../../outputparsers/OutputParserHelpers'

class MistralAIFunctionAgent_Agents implements INode {
    label: string
    name: string
    version: number
    description: string
    type: string
    icon: string
    category: string
    baseClasses: string[]
    inputs: INodeParams[]
    sessionId?: string
    badge?: string

    constructor(fields?: { sessionId?: string }) {
        this.label = 'MistralAI Function Agent'
        this.name = 'mistralAIFunctionAgent'
        this.version = 1.0
        this.type = 'AgentExecutor'
        this.category = 'Agents'
        this.icon = 'MistralAI.svg'
        this.badge = 'NEW'
        this.description = `An agent that uses MistralAI Function Calling to pick the tool and args to call`
        this.baseClasses = [this.type, ...getBaseClasses(AgentExecutor)]
        this.inputs = [
            {
                label: 'Tools',
                name: 'tools',
                type: 'Tool',
                list: true
            },
            {
                label: 'Memory',
                name: 'memory',
                type: 'BaseChatMemory'
            },
            {
                label: 'MistralAI Chat Model',
                name: 'model',
                type: 'BaseChatModel'
            },
            {
                label: 'System Message',
                name: 'systemMessage',
                type: 'string',
                rows: 4,
                optional: true,
                additionalParams: true
            },
            {
                label: 'Input Moderation',
                description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
                name: 'inputModeration',
                type: 'Moderation',
                optional: true,
                list: true
            }
        ]
        this.sessionId = fields?.sessionId
    }

    async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
        return prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory)
    }

    async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
        const memory = nodeData.inputs?.memory as FlowiseMemory
        const moderations = nodeData.inputs?.inputModeration as Moderation[]

        if (moderations && moderations.length > 0) {
            try {
                // Use the output of the moderation chain as input for the MistralAI Function Agent
                input = await checkInputs(moderations, input)
            } catch (e) {
                await new Promise((resolve) => setTimeout(resolve, 500))
                streamResponse(options.socketIO && options.socketIOClientId, e.message, options.socketIO, options.socketIOClientId)
                return formatResponse(e.message)
            }
        }

        const executor = prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input }, options.chatHistory)

        const loggerHandler = new ConsoleCallbackHandler(options.logger)
        const callbacks = await additionalCallbacks(nodeData, options)

        let res: ChainValues = {}
        let sourceDocuments: ICommonObject[] = []
        let usedTools: IUsedTool[] = []

        if (options.socketIO && options.socketIOClientId) {
            const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
            res = await executor.invoke({ input }, { callbacks: [loggerHandler, handler, ...callbacks] })
            if (res.sourceDocuments) {
                options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
                sourceDocuments = res.sourceDocuments
            }
            if (res.usedTools) {
                options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
                usedTools = res.usedTools
            }
        } else {
            res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
            if (res.sourceDocuments) {
                sourceDocuments = res.sourceDocuments
            }
            if (res.usedTools) {
                usedTools = res.usedTools
            }
        }

        await memory.addChatMessages(
            [
                {
                    text: input,
                    type: 'userMessage'
                },
                {
                    text: res?.output,
                    type: 'apiMessage'
                }
            ],
            this.sessionId
        )

        let finalRes = res?.output

        if (sourceDocuments.length || usedTools.length) {
            finalRes = { text: res?.output }
            if (sourceDocuments.length) {
                finalRes.sourceDocuments = flatten(sourceDocuments)
            }
            if (usedTools.length) {
                finalRes.usedTools = usedTools
            }
            return finalRes
        }

        return finalRes
    }
}

const prepareAgent = (
    nodeData: INodeData,
    flowObj: { sessionId?: string; chatId?: string; input?: string },
    chatHistory: IMessage[] = []
) => {
    const model = nodeData.inputs?.model as ChatOpenAI
    const memory = nodeData.inputs?.memory as FlowiseMemory
    const systemMessage = nodeData.inputs?.systemMessage as string
    let tools = nodeData.inputs?.tools
    tools = flatten(tools)
    const memoryKey = memory.memoryKey ? memory.memoryKey : 'chat_history'
    const inputKey = memory.inputKey ? memory.inputKey : 'input'

    const prompt = ChatPromptTemplate.fromMessages([
        ['system', systemMessage ? systemMessage : `You are a helpful AI assistant.`],
        new MessagesPlaceholder(memoryKey),
        ['human', `{${inputKey}}`],
        new MessagesPlaceholder('agent_scratchpad')
    ])

    const llmWithTools = model.bind({
        tools: tools.map(convertToOpenAITool)
    })

    const runnableAgent = RunnableSequence.from([
        {
            [inputKey]: (i: { input: string; steps: AgentStep[] }) => i.input,
            agent_scratchpad: (i: { input: string; steps: AgentStep[] }) => formatAgentSteps(i.steps),
            [memoryKey]: async (_: { input: string; steps: AgentStep[] }) => {
                const messages = (await memory.getChatMessages(flowObj?.sessionId, true, chatHistory)) as BaseMessage[]
                return messages ?? []
            }
        },
        prompt,
        llmWithTools,
        new OpenAIToolsAgentOutputParser()
    ])

    const executor = AgentExecutor.fromAgentAndTools({
        agent: runnableAgent,
        tools,
        sessionId: flowObj?.sessionId,
        chatId: flowObj?.chatId,
        input: flowObj?.input,
        verbose: process.env.DEBUG === 'true' ? true : false
    })

    return executor
}

module.exports = { nodeClass: MistralAIFunctionAgent_Agents }
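
Note the return contract established above: run() resolves to a plain string when the executor produced no extras, and to an object once sourceDocuments or usedTools are present. A minimal consumer sketch follows; the object shape is transcribed from the code above, while the exact fields of IUsedTool are an assumption to be verified against src/Interface:

// Sketch only: local stand-ins for Flowise's Interface types.
type UsedTool = { tool: string; toolInput: Record<string, unknown>; toolOutput: string } // assumed fields
type AgentRunResult = string | { text: string; sourceDocuments?: unknown[]; usedTools?: UsedTool[] }

function renderResult(res: AgentRunResult): string {
    if (typeof res === 'string') return res
    const toolNote = res.usedTools?.length ? ` (${res.usedTools.length} tool call(s))` : ''
    return `${res.text}${toolNote}`
}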
@@ -7,7 +7,7 @@ import { ChatOpenAI, formatToOpenAIFunction } from '@langchain/openai'
 import { ChatPromptTemplate, MessagesPlaceholder } from '@langchain/core/prompts'
 import { OpenAIFunctionsAgentOutputParser } from 'langchain/agents/openai/output_parser'
 import { getBaseClasses } from '../../../src/utils'
-import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface'
+import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams, IUsedTool } from '../../../src/Interface'
 import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
 import { AgentExecutor, formatAgentSteps } from '../../../src/agents'
 import { Moderation, checkInputs } from '../../moderation/Moderation'
@@ -97,6 +97,7 @@ class OpenAIFunctionAgent_Agents implements INode {

         let res: ChainValues = {}
         let sourceDocuments: ICommonObject[] = []
+        let usedTools: IUsedTool[] = []

         if (options.socketIO && options.socketIOClientId) {
             const handler = new CustomChainHandler(options.socketIO, options.socketIOClientId)
@@ -105,11 +106,18 @@
                 options.socketIO.to(options.socketIOClientId).emit('sourceDocuments', flatten(res.sourceDocuments))
                 sourceDocuments = res.sourceDocuments
             }
+            if (res.usedTools) {
+                options.socketIO.to(options.socketIOClientId).emit('usedTools', res.usedTools)
+                usedTools = res.usedTools
+            }
         } else {
             res = await executor.invoke({ input }, { callbacks: [loggerHandler, ...callbacks] })
             if (res.sourceDocuments) {
                 sourceDocuments = res.sourceDocuments
             }
+            if (res.usedTools) {
+                usedTools = res.usedTools
+            }
         }

         await memory.addChatMessages(
@@ -126,7 +134,20 @@
             this.sessionId
         )

-        return sourceDocuments.length ? { text: res?.output, sourceDocuments: flatten(sourceDocuments) } : res?.output
+        let finalRes = res?.output
+
+        if (sourceDocuments.length || usedTools.length) {
+            finalRes = { text: res?.output }
+            if (sourceDocuments.length) {
+                finalRes.sourceDocuments = flatten(sourceDocuments)
+            }
+            if (usedTools.length) {
+                finalRes.usedTools = usedTools
+            }
+            return finalRes
+        }
+
+        return finalRes
     }
 }

(pnpm-lock.yaml: large diff not loaded)