Skip to content

Commit

Permalink
Feature/Add prepend messages to memory (FlowiseAI#2410)
Browse files Browse the repository at this point in the history
add prepend messages to memory
  • Loading branch information
HenryHengZJ authored May 20, 2024
1 parent 816436f commit 8caca47
Show file tree
Hide file tree
Showing 27 changed files with 219 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ const prepareAgent = async (
const systemMessage = nodeData.inputs?.systemMessage as string
const memoryKey = memory.memoryKey ? memory.memoryKey : 'chat_history'
const inputKey = memory.inputKey ? memory.inputKey : 'input'
const prependMessages = options?.prependMessages

const outputParser = ChatConversationalAgent.getDefaultOutputParser({
llm: model,
Expand Down Expand Up @@ -240,7 +241,7 @@ const prepareAgent = async (
[inputKey]: (i: { input: string; steps: AgentStep[] }) => i.input,
agent_scratchpad: async (i: { input: string; steps: AgentStep[] }) => await constructScratchPad(i.steps),
[memoryKey]: async (_: { input: string; steps: AgentStep[] }) => {
const messages = (await memory.getChatMessages(flowObj?.sessionId, true)) as BaseMessage[]
const messages = (await memory.getChatMessages(flowObj?.sessionId, true, prependMessages)) as BaseMessage[]
return messages ?? []
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class ConversationalRetrievalAgent_Agents implements INode {
}

async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
return prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input })
return prepareAgent(nodeData, options, { sessionId: this.sessionId, chatId: options.chatId, input })
}

async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
Expand All @@ -102,7 +102,7 @@ class ConversationalRetrievalAgent_Agents implements INode {
}
}

const executor = prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input })
const executor = prepareAgent(nodeData, options, { sessionId: this.sessionId, chatId: options.chatId, input })

const loggerHandler = new ConsoleCallbackHandler(options.logger)
const callbacks = await additionalCallbacks(nodeData, options)
Expand Down Expand Up @@ -134,7 +134,7 @@ class ConversationalRetrievalAgent_Agents implements INode {
}
}

const prepareAgent = (nodeData: INodeData, flowObj: { sessionId?: string; chatId?: string; input?: string }) => {
const prepareAgent = (nodeData: INodeData, options: ICommonObject, flowObj: { sessionId?: string; chatId?: string; input?: string }) => {
const model = nodeData.inputs?.model as ChatOpenAI
const memory = nodeData.inputs?.memory as FlowiseMemory
const systemMessage = nodeData.inputs?.systemMessage as string
Expand All @@ -143,6 +143,7 @@ const prepareAgent = (nodeData: INodeData, flowObj: { sessionId?: string; chatId
tools = flatten(tools)
const memoryKey = memory.memoryKey ? memory.memoryKey : 'chat_history'
const inputKey = memory.inputKey ? memory.inputKey : 'input'
const prependMessages = options?.prependMessages

const prompt = ChatPromptTemplate.fromMessages([
['ai', systemMessage ? systemMessage : defaultMessage],
Expand All @@ -160,7 +161,7 @@ const prepareAgent = (nodeData: INodeData, flowObj: { sessionId?: string; chatId
[inputKey]: (i: { input: string; steps: AgentStep[] }) => i.input,
agent_scratchpad: (i: { input: string; steps: AgentStep[] }) => formatAgentSteps(i.steps),
[memoryKey]: async (_: { input: string; steps: AgentStep[] }) => {
const messages = (await memory.getChatMessages(flowObj?.sessionId, true)) as BaseMessage[]
const messages = (await memory.getChatMessages(flowObj?.sessionId, true, prependMessages)) as BaseMessage[]
return messages ?? []
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class MistralAIToolAgent_Agents implements INode {
}

async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
return prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input })
return prepareAgent(nodeData, options, { sessionId: this.sessionId, chatId: options.chatId, input })
}

async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
Expand All @@ -100,7 +100,7 @@ class MistralAIToolAgent_Agents implements INode {
}
}

const executor = prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input })
const executor = prepareAgent(nodeData, options, { sessionId: this.sessionId, chatId: options.chatId, input })

const loggerHandler = new ConsoleCallbackHandler(options.logger)
const callbacks = await additionalCallbacks(nodeData, options)
Expand Down Expand Up @@ -161,7 +161,7 @@ class MistralAIToolAgent_Agents implements INode {
}
}

const prepareAgent = (nodeData: INodeData, flowObj: { sessionId?: string; chatId?: string; input?: string }) => {
const prepareAgent = (nodeData: INodeData, options: ICommonObject, flowObj: { sessionId?: string; chatId?: string; input?: string }) => {
const model = nodeData.inputs?.model as ChatOpenAI
const memory = nodeData.inputs?.memory as FlowiseMemory
const maxIterations = nodeData.inputs?.maxIterations as string
Expand All @@ -170,6 +170,7 @@ const prepareAgent = (nodeData: INodeData, flowObj: { sessionId?: string; chatId
tools = flatten(tools)
const memoryKey = memory.memoryKey ? memory.memoryKey : 'chat_history'
const inputKey = memory.inputKey ? memory.inputKey : 'input'
const prependMessages = options?.prependMessages

const prompt = ChatPromptTemplate.fromMessages([
['system', systemMessage ? systemMessage : `You are a helpful AI assistant.`],
Expand All @@ -187,7 +188,7 @@ const prepareAgent = (nodeData: INodeData, flowObj: { sessionId?: string; chatId
[inputKey]: (i: { input: string; steps: AgentStep[] }) => i.input,
agent_scratchpad: (i: { input: string; steps: AgentStep[] }) => formatAgentSteps(i.steps),
[memoryKey]: async (_: { input: string; steps: AgentStep[] }) => {
const messages = (await memory.getChatMessages(flowObj?.sessionId, true)) as BaseMessage[]
const messages = (await memory.getChatMessages(flowObj?.sessionId, true, prependMessages)) as BaseMessage[]
return messages ?? []
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class OpenAIFunctionAgent_Agents implements INode {
}

async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
return prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input })
return prepareAgent(nodeData, options, { sessionId: this.sessionId, chatId: options.chatId, input })
}

async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
Expand All @@ -99,7 +99,7 @@ class OpenAIFunctionAgent_Agents implements INode {
}
}

const executor = prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input })
const executor = prepareAgent(nodeData, options, { sessionId: this.sessionId, chatId: options.chatId, input })

const loggerHandler = new ConsoleCallbackHandler(options.logger)
const callbacks = await additionalCallbacks(nodeData, options)
Expand Down Expand Up @@ -160,7 +160,7 @@ class OpenAIFunctionAgent_Agents implements INode {
}
}

const prepareAgent = (nodeData: INodeData, flowObj: { sessionId?: string; chatId?: string; input?: string }) => {
const prepareAgent = (nodeData: INodeData, options: ICommonObject, flowObj: { sessionId?: string; chatId?: string; input?: string }) => {
const model = nodeData.inputs?.model as ChatOpenAI
const maxIterations = nodeData.inputs?.maxIterations as string
const memory = nodeData.inputs?.memory as FlowiseMemory
Expand All @@ -169,6 +169,7 @@ const prepareAgent = (nodeData: INodeData, flowObj: { sessionId?: string; chatId
tools = flatten(tools)
const memoryKey = memory.memoryKey ? memory.memoryKey : 'chat_history'
const inputKey = memory.inputKey ? memory.inputKey : 'input'
const prependMessages = options?.prependMessages

const prompt = ChatPromptTemplate.fromMessages([
['system', systemMessage ? systemMessage : `You are a helpful AI assistant.`],
Expand All @@ -186,7 +187,7 @@ const prepareAgent = (nodeData: INodeData, flowObj: { sessionId?: string; chatId
[inputKey]: (i: { input: string; steps: AgentStep[] }) => i.input,
agent_scratchpad: (i: { input: string; steps: AgentStep[] }) => formatAgentSteps(i.steps),
[memoryKey]: async (_: { input: string; steps: AgentStep[] }) => {
const messages = (await memory.getChatMessages(flowObj?.sessionId, true)) as BaseMessage[]
const messages = (await memory.getChatMessages(flowObj?.sessionId, true, prependMessages)) as BaseMessage[]
return messages ?? []
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class OpenAIToolAgent_Agents implements INode {
}

async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
return prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input })
return prepareAgent(nodeData, options, { sessionId: this.sessionId, chatId: options.chatId, input })
}

async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
Expand All @@ -100,7 +100,7 @@ class OpenAIToolAgent_Agents implements INode {
}
}

const executor = prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input })
const executor = prepareAgent(nodeData, options, { sessionId: this.sessionId, chatId: options.chatId, input })

const loggerHandler = new ConsoleCallbackHandler(options.logger)
const callbacks = await additionalCallbacks(nodeData, options)
Expand Down Expand Up @@ -161,7 +161,7 @@ class OpenAIToolAgent_Agents implements INode {
}
}

const prepareAgent = (nodeData: INodeData, flowObj: { sessionId?: string; chatId?: string; input?: string }) => {
const prepareAgent = (nodeData: INodeData, options: ICommonObject, flowObj: { sessionId?: string; chatId?: string; input?: string }) => {
const model = nodeData.inputs?.model as ChatOpenAI
const maxIterations = nodeData.inputs?.maxIterations as string
const memory = nodeData.inputs?.memory as FlowiseMemory
Expand All @@ -170,6 +170,7 @@ const prepareAgent = (nodeData: INodeData, flowObj: { sessionId?: string; chatId
tools = flatten(tools)
const memoryKey = memory.memoryKey ? memory.memoryKey : 'chat_history'
const inputKey = memory.inputKey ? memory.inputKey : 'input'
const prependMessages = options?.prependMessages

const prompt = ChatPromptTemplate.fromMessages([
['system', systemMessage ? systemMessage : `You are a helpful AI assistant.`],
Expand All @@ -185,7 +186,7 @@ const prepareAgent = (nodeData: INodeData, flowObj: { sessionId?: string; chatId
[inputKey]: (i: { input: string; steps: ToolsAgentStep[] }) => i.input,
agent_scratchpad: (i: { input: string; steps: ToolsAgentStep[] }) => formatToOpenAIToolMessages(i.steps),
[memoryKey]: async (_: { input: string; steps: ToolsAgentStep[] }) => {
const messages = (await memory.getChatMessages(flowObj?.sessionId, true)) as BaseMessage[]
const messages = (await memory.getChatMessages(flowObj?.sessionId, true, prependMessages)) as BaseMessage[]
return messages ?? []
}
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,12 @@ class OpenAIFunctionAgent_LlamaIndex_Agents implements INode {
return null
}

async run(nodeData: INodeData, input: string): Promise<string | ICommonObject> {
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | ICommonObject> {
const memory = nodeData.inputs?.memory as FlowiseMemory
const model = nodeData.inputs?.model as OpenAI
const systemMessage = nodeData.inputs?.systemMessage as string
const prependMessages = options?.prependMessages

let tools = nodeData.inputs?.tools
tools = flatten(tools)

Expand All @@ -77,7 +79,7 @@ class OpenAIFunctionAgent_LlamaIndex_Agents implements INode {
})
}

const msgs = (await memory.getChatMessages(this.sessionId, false)) as IMessage[]
const msgs = (await memory.getChatMessages(this.sessionId, false, prependMessages)) as IMessage[]
for (const message of msgs) {
if (message.type === 'apiMessage') {
chatHistory.push({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ class ReActAgentChat_Agents implements INode {
const model = nodeData.inputs?.model as BaseChatModel
let tools = nodeData.inputs?.tools as Tool[]
const moderations = nodeData.inputs?.inputModeration as Moderation[]
const prependMessages = options?.prependMessages

if (moderations && moderations.length > 0) {
try {
Expand Down Expand Up @@ -134,7 +135,7 @@ class ReActAgentChat_Agents implements INode {

const callbacks = await additionalCallbacks(nodeData, options)

const chatHistory = ((await memory.getChatMessages(this.sessionId, false)) as IMessage[]) ?? []
const chatHistory = ((await memory.getChatMessages(this.sessionId, false, prependMessages)) as IMessage[]) ?? []
const chatHistoryString = chatHistory.map((hist) => hist.message).join('\\n')

const result = await executor.invoke({ input, chat_history: chatHistoryString }, { callbacks })
Expand Down
3 changes: 2 additions & 1 deletion packages/components/nodes/agents/ToolAgent/ToolAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ const prepareAgent = async (
tools = flatten(tools)
const memoryKey = memory.memoryKey ? memory.memoryKey : 'chat_history'
const inputKey = memory.inputKey ? memory.inputKey : 'input'
const prependMessages = options?.prependMessages

const prompt = ChatPromptTemplate.fromMessages([
['system', systemMessage],
Expand Down Expand Up @@ -239,7 +240,7 @@ const prepareAgent = async (
[inputKey]: (i: { input: string; steps: ToolsAgentStep[] }) => i.input,
agent_scratchpad: (i: { input: string; steps: ToolsAgentStep[] }) => formatToOpenAIToolMessages(i.steps),
[memoryKey]: async (_: { input: string; steps: ToolsAgentStep[] }) => {
const messages = (await memory.getChatMessages(flowObj?.sessionId, true)) as BaseMessage[]
const messages = (await memory.getChatMessages(flowObj?.sessionId, true, prependMessages)) as BaseMessage[]
return messages ?? []
}
},
Expand Down
11 changes: 8 additions & 3 deletions packages/components/nodes/agents/XMLAgent/XMLAgent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class XMLAgent_Agents implements INode {
return formatResponse(e.message)
}
}
const executor = await prepareAgent(nodeData, { sessionId: this.sessionId, chatId: options.chatId, input })
const executor = await prepareAgent(nodeData, options, { sessionId: this.sessionId, chatId: options.chatId, input })

const loggerHandler = new ConsoleCallbackHandler(options.logger)
const callbacks = await additionalCallbacks(nodeData, options)
Expand Down Expand Up @@ -183,7 +183,11 @@ class XMLAgent_Agents implements INode {
}
}

const prepareAgent = async (nodeData: INodeData, flowObj: { sessionId?: string; chatId?: string; input?: string }) => {
const prepareAgent = async (
nodeData: INodeData,
options: ICommonObject,
flowObj: { sessionId?: string; chatId?: string; input?: string }
) => {
const model = nodeData.inputs?.model as BaseChatModel
const maxIterations = nodeData.inputs?.maxIterations as string
const memory = nodeData.inputs?.memory as FlowiseMemory
Expand All @@ -192,6 +196,7 @@ const prepareAgent = async (nodeData: INodeData, flowObj: { sessionId?: string;
tools = flatten(tools)
const inputKey = memory.inputKey ? memory.inputKey : 'input'
const memoryKey = memory.memoryKey ? memory.memoryKey : 'chat_history'
const prependMessages = options?.prependMessages

let promptMessage = systemMessage ? systemMessage : defaultSystemMessage
if (memory.memoryKey) promptMessage = promptMessage.replaceAll('{chat_history}', `{${memory.memoryKey}}`)
Expand All @@ -210,7 +215,7 @@ const prepareAgent = async (nodeData: INodeData, flowObj: { sessionId?: string;

const llmWithStop = model.bind({ stop: ['</tool_input>', '</final_answer>'] })

const messages = (await memory.getChatMessages(flowObj.sessionId, false)) as IMessage[]
const messages = (await memory.getChatMessages(flowObj.sessionId, false, prependMessages)) as IMessage[]
let chatHistoryMsgTxt = ''
for (const message of messages) {
if (message.type === 'apiMessage') {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ const prepareChain = async (nodeData: INodeData, options: ICommonObject, session
let model = nodeData.inputs?.model as BaseChatModel
const memory = nodeData.inputs?.memory as FlowiseMemory
const memoryKey = memory.memoryKey ?? 'chat_history'
const prependMessages = options?.prependMessages

let messageContent: MessageContentImageUrl[] = []
if (llmSupportsVision(model)) {
Expand Down Expand Up @@ -252,7 +253,7 @@ const prepareChain = async (nodeData: INodeData, options: ICommonObject, session
{
[inputKey]: (input: { input: string }) => input.input,
[memoryKey]: async () => {
const history = await memory.getChatMessages(sessionId, true)
const history = await memory.getChatMessages(sessionId, true, prependMessages)
return history
},
...promptVariables
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ class ConversationalRetrievalQAChain_Chains implements INode {
const rephrasePrompt = nodeData.inputs?.rephrasePrompt as string
const responsePrompt = nodeData.inputs?.responsePrompt as string
const returnSourceDocuments = nodeData.inputs?.returnSourceDocuments as boolean
const prependMessages = options?.prependMessages

const appDataSource = options.appDataSource as DataSource
const databaseEntities = options.databaseEntities as IDatabaseEntity
Expand Down Expand Up @@ -210,7 +211,7 @@ class ConversationalRetrievalQAChain_Chains implements INode {
}
const answerChain = createChain(model, vectorStoreRetriever, rephrasePrompt, customResponsePrompt)

const history = ((await memory.getChatMessages(this.sessionId, false)) as IMessage[]) ?? []
const history = ((await memory.getChatMessages(this.sessionId, false, prependMessages)) as IMessage[]) ?? []

const loggerHandler = new ConsoleCallbackHandler(options.logger)
const additionalCallback = await additionalCallbacks(nodeData, options)
Expand Down Expand Up @@ -401,7 +402,11 @@ class BufferMemory extends FlowiseMemory implements MemoryMethods {
this.chatflowid = fields.chatflowid
}

async getChatMessages(overrideSessionId = '', returnBaseMessages = false): Promise<IMessage[] | BaseMessage[]> {
async getChatMessages(
overrideSessionId = '',
returnBaseMessages = false,
prependMessages?: IMessage[]
): Promise<IMessage[] | BaseMessage[]> {
if (!overrideSessionId) return []

const chatMessage = await this.appDataSource.getRepository(this.databaseEntities['ChatMessage']).find({
Expand All @@ -414,6 +419,10 @@ class BufferMemory extends FlowiseMemory implements MemoryMethods {
}
})

if (prependMessages?.length) {
chatMessage.unshift(...prependMessages)
}

if (returnBaseMessages) {
return mapChatMessageToBaseMessage(chatMessage)
}
Expand Down
Loading

0 comments on commit 8caca47

Please sign in to comment.