Skip to content

Commit

Permalink
Passing state to tool so that we can use them in custom tools (#3103)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jrakru authored Aug 30, 2024
1 parent 7a5246d commit 2e45851
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 5 deletions.
46 changes: 41 additions & 5 deletions packages/components/nodes/sequentialagents/ToolNode/ToolNode.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
import { flatten } from 'lodash'
import { ICommonObject, IDatabaseEntity, INode, INodeData, INodeParams, ISeqAgentNode, IUsedTool } from '../../../src/Interface'
import {
ICommonObject,
IDatabaseEntity,
INode,
INodeData,
INodeParams,
ISeqAgentNode,
IUsedTool,
IStateWithMessages
} from '../../../src/Interface'
import { AIMessage, AIMessageChunk, BaseMessage, ToolMessage } from '@langchain/core/messages'
import { StructuredTool } from '@langchain/core/tools'
import { RunnableConfig } from '@langchain/core/runnables'
Expand All @@ -9,6 +18,7 @@ import { DataSource } from 'typeorm'
import { MessagesState, RunnableCallable, customGet, getVM } from '../commonUtils'
import { getVars, prepareSandboxVars } from '../../../src/utils'
import { ChatPromptTemplate } from '@langchain/core/prompts'
import { DynamicStructuredTool } from '../../tools/CustomTool/core'

const defaultApprovalPrompt = `You are about to execute tool: {tools}. Ask if user want to proceed`

Expand Down Expand Up @@ -350,7 +360,7 @@ class ToolNode_SeqAgents implements INode {
}
}

class ToolNode<T extends BaseMessage[] | MessagesState> extends RunnableCallable<T, T> {
class ToolNode<T extends IStateWithMessages | BaseMessage[] | MessagesState> extends RunnableCallable<T, BaseMessage[] | MessagesState> {
tools: StructuredTool[]
nodeData: INodeData
inputQuery: string
Expand All @@ -372,19 +382,45 @@ class ToolNode<T extends BaseMessage[] | MessagesState> extends RunnableCallable
this.options = options
}

private async run(input: BaseMessage[] | MessagesState, config: RunnableConfig): Promise<BaseMessage[] | MessagesState> {
const message = Array.isArray(input) ? input[input.length - 1] : input.messages[input.messages.length - 1]
private async run(input: T, config: RunnableConfig): Promise<BaseMessage[] | MessagesState> {
let messages: BaseMessage[]

// Check if input is an array of BaseMessage[]
if (Array.isArray(input)) {
messages = input
}
// Check if input is IStateWithMessages
else if ((input as IStateWithMessages).messages) {
messages = (input as IStateWithMessages).messages
}
// Handle MessagesState type
else {
messages = (input as MessagesState).messages
}

// Get the last message
const message = messages[messages.length - 1]

if (message._getType() !== 'ai') {
throw new Error('ToolNode only accepts AIMessages as input.')
}

// Extract all properties except messages for IStateWithMessages
const { messages: _, ...inputWithoutMessages } = Array.isArray(input) ? { messages: input } : input
const ChannelsWithoutMessages = {
state: inputWithoutMessages
}

const outputs = await Promise.all(
(message as AIMessage).tool_calls?.map(async (call) => {
const tool = this.tools.find((tool) => tool.name === call.name)
if (tool === undefined) {
throw new Error(`Tool ${call.name} not found.`)
}
if (tool && tool instanceof DynamicStructuredTool) {
// @ts-ignore
tool.setFlowObject(ChannelsWithoutMessages)
}
let output = await tool.invoke(call.args, config)
let sourceDocuments: Document[] = []
if (output?.includes(SOURCE_DOCUMENTS_PREFIX)) {
Expand Down Expand Up @@ -436,7 +472,7 @@ const getReturnOutput = async (
input: string,
options: ICommonObject,
outputs: ToolMessage[],
state: BaseMessage[] | MessagesState
state: ICommonObject
) => {
const appDataSource = options.appDataSource as DataSource
const databaseEntities = options.databaseEntities as IDatabaseEntity
Expand Down
4 changes: 4 additions & 0 deletions packages/components/src/Interface.ts
Original file line number Diff line number Diff line change
Expand Up @@ -396,3 +396,7 @@ export interface IVisionChatModal {
revertToOriginalModel(): void
setMultiModalOption(multiModalOption: IMultiModalOption): void
}
export interface IStateWithMessages extends ICommonObject {
messages: BaseMessage[]
[key: string]: any
}

0 comments on commit 2e45851

Please sign in to comment.