From 212a566dfc90da54728b7ff2f3b391b02a6f5c91 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 16 Aug 2024 16:42:36 -0300 Subject: [PATCH] fix: Handle KeyError in template parameter mapping and suggest closest match if not found (#3366) --- .../langflow/custom/custom_component/component.py | 11 ++++++++++- src/backend/base/langflow/utils/util.py | 11 +++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index daa8f9cccac..992e3724a3d 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -16,6 +16,7 @@ from langflow.template.field.base import UNDEFINED, Input, Output from langflow.template.frontend_node.custom_components import ComponentFrontendNode from langflow.utils.async_helpers import run_until_complete +from langflow.utils.util import find_closest_match from .custom_component import CustomComponent @@ -415,7 +416,15 @@ def _map_parameters_on_frontend_node(self, frontend_node: ComponentFrontendNode) def _map_parameters_on_template(self, template: dict): for name, value in self._parameters.items(): - template[name]["value"] = value + try: + template[name]["value"] = value + except KeyError: + close_match = find_closest_match(name, list(template.keys())) + if close_match: + raise ValueError( + f"Parameter '{name}' not found in {self.__class__.__name__}. " f"Did you mean '{close_match}'?" + ) + raise ValueError(f"Parameter {name} not found in {self.__class__.__name__}. ") def _get_method_return_type(self, method_name: str) -> List[str]: method = getattr(self, method_name) diff --git a/src/backend/base/langflow/utils/util.py b/src/backend/base/langflow/utils/util.py index de70da2db35..ba9467edf50 100644 --- a/src/backend/base/langflow/utils/util.py +++ b/src/backend/base/langflow/utils/util.py @@ -1,3 +1,4 @@ +import difflib import importlib import inspect import json @@ -465,3 +466,13 @@ def is_class_method(func, cls): def escape_json_dump(edge_dict): return json.dumps(edge_dict).replace('"', "œ") + + +def find_closest_match(string: str, list_of_strings: list[str]) -> str | None: + """ + Find the closest match in a list of strings. + """ + closest_match = difflib.get_close_matches(string, list_of_strings, n=1, cutoff=0.2) + if closest_match: + return closest_match[0] + return None