Skip to content

Commit

Permalink
refactor(core): Add support for multiple invocation of tools
Browse files Browse the repository at this point in the history
  • Loading branch information
netroy committed Dec 10, 2024
1 parent 5817ac1 commit fe4fc7d
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 211 deletions.
73 changes: 33 additions & 40 deletions packages/core/src/CreateNodeAsTool.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { DynamicStructuredTool } from '@langchain/core/tools';
import type {
IExecuteFunctions,
INode,
INodeParameters,
INodeType,
ISupplyDataFunctions,
ITaskDataConnections,
} from 'n8n-workflow';
import { jsonParse, NodeConnectionType, NodeOperationError } from 'n8n-workflow';
import { z } from 'zod';
Expand All @@ -16,22 +18,25 @@ interface FromAIArgument {
defaultValue?: string | number | boolean | Record<string, unknown>;
}

type ParserOptions = {
node: INode;
nodeType: INodeType;
contextFactory: (runIndex: number, inputData: ITaskDataConnections) => ISupplyDataFunctions;
};

/**
* AIParametersParser
*
* This class encapsulates the logic for parsing node parameters, extracting $fromAI calls,
* generating Zod schemas, and creating LangChain tools.
*/
class AIParametersParser {
private ctx: ISupplyDataFunctions;
private runIndex = 0;

/**
* Constructs an instance of AIParametersParser.
* @param ctx The execution context.
*/
constructor(ctx: ISupplyDataFunctions) {
this.ctx = ctx;
}
constructor(private readonly options: ParserOptions) {}

/**
* Generates a Zod schema based on the provided FromAIArgument placeholder.
Expand Down Expand Up @@ -162,14 +167,14 @@ class AIParametersParser {
} catch (error) {
// If parsing fails, throw an ApplicationError with details
throw new NodeOperationError(
this.ctx.getNode(),
this.options.node,
`Failed to parse $fromAI arguments: ${argsString}: ${error}`,
);
}
} else {
// Log an error if parentheses are unbalanced
throw new NodeOperationError(
this.ctx.getNode(),
this.options.node,
`Unbalanced parentheses while parsing $fromAI call: ${str.slice(startIndex)}`,
);
}
Expand Down Expand Up @@ -254,7 +259,7 @@ class AIParametersParser {
const type = cleanArgs?.[2] || 'string';

if (!['string', 'number', 'boolean', 'json'].includes(type.toLowerCase())) {
throw new NodeOperationError(this.ctx.getNode(), `Invalid type: ${type}`);
throw new NodeOperationError(this.options.node, `Invalid type: ${type}`);
}

return {
Expand Down Expand Up @@ -315,13 +320,12 @@ class AIParametersParser {

/**
* Creates a DynamicStructuredTool from a node.
* @param node The node type.
* @param nodeParameters The parameters of the node.
* @returns A DynamicStructuredTool instance.
*/
public createTool(node: INodeType, nodeParameters: INodeParameters): DynamicStructuredTool {
public createTool(): DynamicStructuredTool {
const { node, nodeType } = this.options;
const collectedArguments: FromAIArgument[] = [];
this.traverseNodeParameters(nodeParameters, collectedArguments);
this.traverseNodeParameters(node.parameters, collectedArguments);

// Validate each collected argument
const nameValidationRegex = /^[a-zA-Z0-9_-]{1,64}$/;
Expand All @@ -331,7 +335,7 @@ class AIParametersParser {
const isEmptyError = 'You must specify a key when using $fromAI()';
const isInvalidError = `Parameter key \`${argument.key}\` is invalid`;
const error = new Error(argument.key.length === 0 ? isEmptyError : isInvalidError);
throw new NodeOperationError(this.ctx.getNode(), error, {
throw new NodeOperationError(node, error, {
description:
'Invalid parameter key, must be between 1 and 64 characters long and only contain letters, numbers, underscores, and hyphens',
});
Expand All @@ -348,7 +352,7 @@ class AIParametersParser {
) {
// If not, throw an error for inconsistent duplicate keys
throw new NodeOperationError(
this.ctx.getNode(),
node,
`Duplicate key '${argument.key}' found with different description or type`,
{
description:
Expand Down Expand Up @@ -378,37 +382,38 @@ class AIParametersParser {
}, {});

const schema = z.object(schemaObj).required();
const description = this.getDescription(node, nodeParameters);
const nodeName = this.ctx.getNode().name.replace(/ /g, '_');
const name = nodeName || node.description.name;
const description = this.getDescription(nodeType, node.parameters);
const nodeName = node.name.replace(/ /g, '_');
const name = nodeName || nodeType.description.name;

const tool = new DynamicStructuredTool({
name,
description,
schema,
func: async (functionArgs: z.infer<typeof schema>) => {
const { index } = this.ctx.addInputData(NodeConnectionType.AiTool, [
[{ json: functionArgs }],
]);
func: async (toolArgs: z.infer<typeof schema>) => {
const context = this.options.contextFactory(this.runIndex, {});
context.addInputData(NodeConnectionType.AiTool, [[{ json: toolArgs }]]);

try {
// Execute the node with the proxied context
const result = await node.execute?.bind(this.ctx as IExecuteFunctions)();
const result = await nodeType.execute?.call(context as IExecuteFunctions);

// Process and map the results
const mappedResults = result?.[0]?.flatMap((item) => item.json);

// Add output data to the context
this.ctx.addOutputData(NodeConnectionType.AiTool, index, [
context.addOutputData(NodeConnectionType.AiTool, this.runIndex, [
[{ json: { response: mappedResults } }],
]);

// Return the stringified results
return JSON.stringify(mappedResults);
} catch (error) {
const nodeError = new NodeOperationError(this.ctx.getNode(), error as Error);
this.ctx.addOutputData(NodeConnectionType.AiTool, index, nodeError);
const nodeError = new NodeOperationError(this.options.node, error as Error);
context.addOutputData(NodeConnectionType.AiTool, this.runIndex, nodeError);
return 'Error during node execution: ' + nodeError.description;
} finally {
this.runIndex++;
}
},
});
Expand All @@ -421,20 +426,8 @@ class AIParametersParser {
* Converts node into LangChain tool by analyzing node parameters,
* identifying placeholders using the $fromAI function, and generating a Zod schema. It then creates
* a DynamicStructuredTool that can be used in LangChain workflows.
*
* @param ctx The execution context.
* @param node The node type.
* @param nodeParameters The parameters of the node.
* @returns An object containing the DynamicStructuredTool instance.
*/
export function createNodeAsTool(
ctx: ISupplyDataFunctions,
node: INodeType,
nodeParameters: INodeParameters,
) {
const parser = new AIParametersParser(ctx);

return {
response: parser.createTool(node, nodeParameters),
};
export function createNodeAsTool(options: ParserOptions) {
const parser = new AIParametersParser(options);
return { response: parser.createTool() };
}
141 changes: 74 additions & 67 deletions packages/core/src/NodeExecuteFunctions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ import type {
DeduplicationScope,
DeduplicationItemTypes,
ICheckProcessedContextData,
ISupplyDataFunctions,
WebhookType,
SchedulingFunctions,
SupplyData,
} from 'n8n-workflow';
import {
NodeConnectionType,
Expand Down Expand Up @@ -2023,9 +2023,9 @@ export async function getInputConnectionData(
this: IAllExecuteFunctions,
workflow: Workflow,
runExecutionData: IRunExecutionData,
runIndex: number,
parentRunIndex: number,
connectionInputData: INodeExecutionData[],
inputData: ITaskDataConnections,
parentInputData: ITaskDataConnections,
additionalData: IWorkflowExecuteAdditionalData,
executeData: IExecuteData,
mode: WorkflowExecuteMode,
Expand All @@ -2034,10 +2034,13 @@ export async function getInputConnectionData(
itemIndex: number,
abortSignal?: AbortSignal,
): Promise<unknown> {
const node = this.getNode();
const nodeType = workflow.nodeTypes.getByNameAndVersion(node.type, node.typeVersion);
const parentNode = this.getNode();
const parentNodeType = workflow.nodeTypes.getByNameAndVersion(
parentNode.type,
parentNode.typeVersion,
);

const inputs = NodeHelpers.getNodeInputs(workflow, node, nodeType.description);
const inputs = NodeHelpers.getNodeInputs(workflow, parentNode, parentNodeType.description);

let inputConfiguration = inputs.find((input) => {
if (typeof input === 'string') {
Expand All @@ -2048,7 +2051,7 @@ export async function getInputConnectionData(

if (inputConfiguration === undefined) {
throw new ApplicationError('Node does not have input of type', {
extra: { nodeName: node.name, connectionType },
extra: { nodeName: parentNode.name, connectionType },
});
}

Expand All @@ -2059,14 +2062,14 @@ export async function getInputConnectionData(
}

const connectedNodes = workflow
.getParentNodes(node.name, connectionType, 1)
.getParentNodes(parentNode.name, connectionType, 1)
.map((nodeName) => workflow.getNode(nodeName) as INode)
.filter((connectedNode) => connectedNode.disabled !== true);

if (connectedNodes.length === 0) {
if (inputConfiguration.required) {
throw new NodeOperationError(
node,
parentNode,
`A ${inputConfiguration?.displayName ?? connectionType} sub-node must be connected and enabled`,
);
}
Expand All @@ -2078,82 +2081,86 @@ export async function getInputConnectionData(
connectedNodes.length > inputConfiguration.maxConnections
) {
throw new NodeOperationError(
node,
parentNode,
`Only ${inputConfiguration.maxConnections} ${connectionType} sub-nodes are/is allowed to be connected`,
);
}

const constParentNodes = connectedNodes.map(async (connectedNode) => {
const nodeType = workflow.nodeTypes.getByNameAndVersion(
const nodes: SupplyData[] = [];
for (const connectedNode of connectedNodes) {
const connectedNodeType = workflow.nodeTypes.getByNameAndVersion(
connectedNode.type,
connectedNode.typeVersion,
);
const context = new SupplyDataContext(
workflow,
connectedNode,
additionalData,
mode,
runExecutionData,
runIndex,
connectionInputData,
inputData,
executeData,
closeFunctions,
abortSignal,
);
const contextFactory = (runIndex: number, inputData: ITaskDataConnections) =>
new SupplyDataContext(
workflow,
connectedNode,
additionalData,
mode,
runExecutionData,
runIndex,
connectionInputData,
inputData,
connectionType,
executeData,
closeFunctions,
abortSignal,
);

if (!nodeType.supplyData) {
if (nodeType.description.outputs.includes(NodeConnectionType.AiTool)) {
nodeType.supplyData = async function (this: ISupplyDataFunctions) {
return createNodeAsTool(this, nodeType, this.getNode().parameters);
};
if (!connectedNodeType.supplyData) {
if (connectedNodeType.description.outputs.includes(NodeConnectionType.AiTool)) {
const supplyData = createNodeAsTool({
node: connectedNode,
nodeType: connectedNodeType,
contextFactory,
});
nodes.push(supplyData);
} else {
throw new ApplicationError('Node does not have a `supplyData` method defined', {
extra: { nodeName: connectedNode.name },
});
}
}
} else {
const context = contextFactory(parentRunIndex, parentInputData);
try {
const supplyData = await connectedNodeType.supplyData.call(context, itemIndex);
if (supplyData.closeFunction) {
closeFunctions.push(supplyData.closeFunction);
}
nodes.push(supplyData);
} catch (error) {
// Propagate errors from sub-nodes
if (error.functionality === 'configuration-node') throw error;
if (!(error instanceof ExecutionBaseError)) {
error = new NodeOperationError(connectedNode, error, {
itemIndex,
});
}

try {
const response = await nodeType.supplyData.call(context, itemIndex);
if (response.closeFunction) {
closeFunctions.push(response.closeFunction);
}
return response;
} catch (error) {
// Propagate errors from sub-nodes
if (error.functionality === 'configuration-node') throw error;
if (!(error instanceof ExecutionBaseError)) {
error = new NodeOperationError(connectedNode, error, {
let currentNodeRunIndex = 0;
if (runExecutionData.resultData.runData.hasOwnProperty(parentNode.name)) {
currentNodeRunIndex = runExecutionData.resultData.runData[parentNode.name].length;
}

// Display the error on the node which is causing it
await context.addExecutionDataFunctions(
'input',
error,
connectionType,
parentNode.name,
currentNodeRunIndex,
);

// Display on the calling node which node has the error
throw new NodeOperationError(connectedNode, `Error in sub-node ${connectedNode.name}`, {
itemIndex,
functionality: 'configuration-node',
description: error.message,
});
}

let currentNodeRunIndex = 0;
if (runExecutionData.resultData.runData.hasOwnProperty(node.name)) {
currentNodeRunIndex = runExecutionData.resultData.runData[node.name].length;
}

// Display the error on the node which is causing it
await context.addExecutionDataFunctions(
'input',
error,
connectionType,
node.name,
currentNodeRunIndex,
);

// Display on the calling node which node has the error
throw new NodeOperationError(connectedNode, `Error in sub-node ${connectedNode.name}`, {
itemIndex,
functionality: 'configuration-node',
description: error.message,
});
}
});

// Validate the inputs
const nodes = await Promise.all(constParentNodes);
}

return inputConfiguration.maxConnections === 1
? (nodes || [])[0]?.response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ describe('SupplyDataContext', () => {
runIndex,
connectionInputData,
inputData,
connectionType,
executeData,
[closeFn],
abortSignal,
Expand Down
Loading

0 comments on commit fe4fc7d

Please sign in to comment.