Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(core): Fix support for multiple invocation of AI tools #12141

Merged
merged 3 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 33 additions & 40 deletions packages/core/src/CreateNodeAsTool.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import { DynamicStructuredTool } from '@langchain/core/tools';
import type {
IExecuteFunctions,
INode,
INodeParameters,
INodeType,
ISupplyDataFunctions,
ITaskDataConnections,
} from 'n8n-workflow';
import { jsonParse, NodeConnectionType, NodeOperationError } from 'n8n-workflow';
import { z } from 'zod';
Expand All @@ -16,22 +18,25 @@ interface FromAIArgument {
defaultValue?: string | number | boolean | Record<string, unknown>;
}

type ParserOptions = {
node: INode;
nodeType: INodeType;
contextFactory: (runIndex: number, inputData: ITaskDataConnections) => ISupplyDataFunctions;
};

/**
* AIParametersParser
*
* This class encapsulates the logic for parsing node parameters, extracting $fromAI calls,
* generating Zod schemas, and creating LangChain tools.
*/
class AIParametersParser {
private ctx: ISupplyDataFunctions;
private runIndex = 0;

/**
* Constructs an instance of AIParametersParser.
* @param ctx The execution context.
*/
constructor(ctx: ISupplyDataFunctions) {
this.ctx = ctx;
}
constructor(private readonly options: ParserOptions) {}

/**
* Generates a Zod schema based on the provided FromAIArgument placeholder.
Expand Down Expand Up @@ -162,14 +167,14 @@ class AIParametersParser {
} catch (error) {
// If parsing fails, throw an ApplicationError with details
throw new NodeOperationError(
this.ctx.getNode(),
this.options.node,
`Failed to parse $fromAI arguments: ${argsString}: ${error}`,
);
}
} else {
// Log an error if parentheses are unbalanced
throw new NodeOperationError(
this.ctx.getNode(),
this.options.node,
`Unbalanced parentheses while parsing $fromAI call: ${str.slice(startIndex)}`,
);
}
Expand Down Expand Up @@ -254,7 +259,7 @@ class AIParametersParser {
const type = cleanArgs?.[2] || 'string';

if (!['string', 'number', 'boolean', 'json'].includes(type.toLowerCase())) {
throw new NodeOperationError(this.ctx.getNode(), `Invalid type: ${type}`);
throw new NodeOperationError(this.options.node, `Invalid type: ${type}`);
}

return {
Expand Down Expand Up @@ -315,13 +320,12 @@ class AIParametersParser {

/**
* Creates a DynamicStructuredTool from a node.
* @param node The node type.
* @param nodeParameters The parameters of the node.
* @returns A DynamicStructuredTool instance.
*/
public createTool(node: INodeType, nodeParameters: INodeParameters): DynamicStructuredTool {
public createTool(): DynamicStructuredTool {
const { node, nodeType } = this.options;
const collectedArguments: FromAIArgument[] = [];
this.traverseNodeParameters(nodeParameters, collectedArguments);
this.traverseNodeParameters(node.parameters, collectedArguments);

// Validate each collected argument
const nameValidationRegex = /^[a-zA-Z0-9_-]{1,64}$/;
Expand All @@ -331,7 +335,7 @@ class AIParametersParser {
const isEmptyError = 'You must specify a key when using $fromAI()';
const isInvalidError = `Parameter key \`${argument.key}\` is invalid`;
const error = new Error(argument.key.length === 0 ? isEmptyError : isInvalidError);
throw new NodeOperationError(this.ctx.getNode(), error, {
throw new NodeOperationError(node, error, {
description:
'Invalid parameter key, must be between 1 and 64 characters long and only contain letters, numbers, underscores, and hyphens',
});
Expand All @@ -348,7 +352,7 @@ class AIParametersParser {
) {
// If not, throw an error for inconsistent duplicate keys
throw new NodeOperationError(
this.ctx.getNode(),
node,
`Duplicate key '${argument.key}' found with different description or type`,
{
description:
Expand Down Expand Up @@ -378,37 +382,38 @@ class AIParametersParser {
}, {});

const schema = z.object(schemaObj).required();
const description = this.getDescription(node, nodeParameters);
const nodeName = this.ctx.getNode().name.replace(/ /g, '_');
const name = nodeName || node.description.name;
const description = this.getDescription(nodeType, node.parameters);
const nodeName = node.name.replace(/ /g, '_');
const name = nodeName || nodeType.description.name;

const tool = new DynamicStructuredTool({
name,
description,
schema,
func: async (functionArgs: z.infer<typeof schema>) => {
const { index } = this.ctx.addInputData(NodeConnectionType.AiTool, [
[{ json: functionArgs }],
]);
func: async (toolArgs: z.infer<typeof schema>) => {
const context = this.options.contextFactory(this.runIndex, {});
context.addInputData(NodeConnectionType.AiTool, [[{ json: toolArgs }]]);

try {
// Execute the node with the proxied context
const result = await node.execute?.bind(this.ctx as IExecuteFunctions)();
const result = await nodeType.execute?.call(context as IExecuteFunctions);

// Process and map the results
const mappedResults = result?.[0]?.flatMap((item) => item.json);

// Add output data to the context
this.ctx.addOutputData(NodeConnectionType.AiTool, index, [
context.addOutputData(NodeConnectionType.AiTool, this.runIndex, [
[{ json: { response: mappedResults } }],
]);

// Return the stringified results
return JSON.stringify(mappedResults);
} catch (error) {
const nodeError = new NodeOperationError(this.ctx.getNode(), error as Error);
this.ctx.addOutputData(NodeConnectionType.AiTool, index, nodeError);
const nodeError = new NodeOperationError(this.options.node, error as Error);
context.addOutputData(NodeConnectionType.AiTool, this.runIndex, nodeError);
return 'Error during node execution: ' + nodeError.description;
} finally {
this.runIndex++;
}
},
});
Expand All @@ -421,20 +426,8 @@ class AIParametersParser {
* Converts node into LangChain tool by analyzing node parameters,
* identifying placeholders using the $fromAI function, and generating a Zod schema. It then creates
* a DynamicStructuredTool that can be used in LangChain workflows.
*
* @param ctx The execution context.
* @param node The node type.
* @param nodeParameters The parameters of the node.
* @returns An object containing the DynamicStructuredTool instance.
*/
export function createNodeAsTool(
ctx: ISupplyDataFunctions,
node: INodeType,
nodeParameters: INodeParameters,
) {
const parser = new AIParametersParser(ctx);

return {
response: parser.createTool(node, nodeParameters),
};
export function createNodeAsTool(options: ParserOptions) {
const parser = new AIParametersParser(options);
return { response: parser.createTool() };
}
141 changes: 74 additions & 67 deletions packages/core/src/NodeExecuteFunctions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ import type {
DeduplicationScope,
DeduplicationItemTypes,
ICheckProcessedContextData,
ISupplyDataFunctions,
WebhookType,
SchedulingFunctions,
SupplyData,
} from 'n8n-workflow';
import {
NodeConnectionType,
Expand Down Expand Up @@ -2023,9 +2023,9 @@ export async function getInputConnectionData(
this: IAllExecuteFunctions,
workflow: Workflow,
runExecutionData: IRunExecutionData,
runIndex: number,
parentRunIndex: number,
connectionInputData: INodeExecutionData[],
inputData: ITaskDataConnections,
parentInputData: ITaskDataConnections,
additionalData: IWorkflowExecuteAdditionalData,
executeData: IExecuteData,
mode: WorkflowExecuteMode,
Expand All @@ -2034,10 +2034,13 @@ export async function getInputConnectionData(
itemIndex: number,
abortSignal?: AbortSignal,
): Promise<unknown> {
const node = this.getNode();
const nodeType = workflow.nodeTypes.getByNameAndVersion(node.type, node.typeVersion);
const parentNode = this.getNode();
const parentNodeType = workflow.nodeTypes.getByNameAndVersion(
parentNode.type,
parentNode.typeVersion,
);

const inputs = NodeHelpers.getNodeInputs(workflow, node, nodeType.description);
const inputs = NodeHelpers.getNodeInputs(workflow, parentNode, parentNodeType.description);

let inputConfiguration = inputs.find((input) => {
if (typeof input === 'string') {
Expand All @@ -2048,7 +2051,7 @@ export async function getInputConnectionData(

if (inputConfiguration === undefined) {
throw new ApplicationError('Node does not have input of type', {
extra: { nodeName: node.name, connectionType },
extra: { nodeName: parentNode.name, connectionType },
});
}

Expand All @@ -2059,14 +2062,14 @@ export async function getInputConnectionData(
}

const connectedNodes = workflow
.getParentNodes(node.name, connectionType, 1)
.getParentNodes(parentNode.name, connectionType, 1)
.map((nodeName) => workflow.getNode(nodeName) as INode)
.filter((connectedNode) => connectedNode.disabled !== true);

if (connectedNodes.length === 0) {
if (inputConfiguration.required) {
throw new NodeOperationError(
node,
parentNode,
`A ${inputConfiguration?.displayName ?? connectionType} sub-node must be connected and enabled`,
);
}
Expand All @@ -2078,82 +2081,86 @@ export async function getInputConnectionData(
connectedNodes.length > inputConfiguration.maxConnections
) {
throw new NodeOperationError(
node,
parentNode,
`Only ${inputConfiguration.maxConnections} ${connectionType} sub-nodes are/is allowed to be connected`,
);
}

const constParentNodes = connectedNodes.map(async (connectedNode) => {
const nodeType = workflow.nodeTypes.getByNameAndVersion(
const nodes: SupplyData[] = [];
for (const connectedNode of connectedNodes) {
const connectedNodeType = workflow.nodeTypes.getByNameAndVersion(
connectedNode.type,
connectedNode.typeVersion,
);
const context = new SupplyDataContext(
workflow,
connectedNode,
additionalData,
mode,
runExecutionData,
runIndex,
connectionInputData,
inputData,
executeData,
closeFunctions,
abortSignal,
);
const contextFactory = (runIndex: number, inputData: ITaskDataConnections) =>
new SupplyDataContext(
workflow,
connectedNode,
additionalData,
mode,
runExecutionData,
runIndex,
connectionInputData,
inputData,
connectionType,
executeData,
closeFunctions,
abortSignal,
);

if (!nodeType.supplyData) {
if (nodeType.description.outputs.includes(NodeConnectionType.AiTool)) {
nodeType.supplyData = async function (this: ISupplyDataFunctions) {
return createNodeAsTool(this, nodeType, this.getNode().parameters);
};
if (!connectedNodeType.supplyData) {
if (connectedNodeType.description.outputs.includes(NodeConnectionType.AiTool)) {
const supplyData = createNodeAsTool({
node: connectedNode,
nodeType: connectedNodeType,
contextFactory,
});
nodes.push(supplyData);
} else {
throw new ApplicationError('Node does not have a `supplyData` method defined', {
extra: { nodeName: connectedNode.name },
});
}
}
} else {
const context = contextFactory(parentRunIndex, parentInputData);
try {
const supplyData = await connectedNodeType.supplyData.call(context, itemIndex);
if (supplyData.closeFunction) {
closeFunctions.push(supplyData.closeFunction);
}
nodes.push(supplyData);
} catch (error) {
// Propagate errors from sub-nodes
if (error.functionality === 'configuration-node') throw error;
if (!(error instanceof ExecutionBaseError)) {
error = new NodeOperationError(connectedNode, error, {
itemIndex,
});
}

try {
const response = await nodeType.supplyData.call(context, itemIndex);
if (response.closeFunction) {
closeFunctions.push(response.closeFunction);
}
return response;
} catch (error) {
// Propagate errors from sub-nodes
if (error.functionality === 'configuration-node') throw error;
if (!(error instanceof ExecutionBaseError)) {
error = new NodeOperationError(connectedNode, error, {
let currentNodeRunIndex = 0;
if (runExecutionData.resultData.runData.hasOwnProperty(parentNode.name)) {
currentNodeRunIndex = runExecutionData.resultData.runData[parentNode.name].length;
}

// Display the error on the node which is causing it
await context.addExecutionDataFunctions(
'input',
error,
connectionType,
parentNode.name,
currentNodeRunIndex,
);

// Display on the calling node which node has the error
throw new NodeOperationError(connectedNode, `Error in sub-node ${connectedNode.name}`, {
itemIndex,
functionality: 'configuration-node',
description: error.message,
});
}

let currentNodeRunIndex = 0;
if (runExecutionData.resultData.runData.hasOwnProperty(node.name)) {
currentNodeRunIndex = runExecutionData.resultData.runData[node.name].length;
}

// Display the error on the node which is causing it
await context.addExecutionDataFunctions(
'input',
error,
connectionType,
node.name,
currentNodeRunIndex,
);

// Display on the calling node which node has the error
throw new NodeOperationError(connectedNode, `Error in sub-node ${connectedNode.name}`, {
itemIndex,
functionality: 'configuration-node',
description: error.message,
});
}
});

// Validate the inputs
const nodes = await Promise.all(constParentNodes);
}

return inputConfiguration.maxConnections === 1
? (nodes || [])[0]?.response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ describe('SupplyDataContext', () => {
runIndex,
connectionInputData,
inputData,
connectionType,
executeData,
[closeFn],
abortSignal,
Expand Down
Loading
Loading