From 7032bc46ef183285144dd31f47280b50846b297a Mon Sep 17 00:00:00 2001
From: Edwin Jose
Date: Thu, 31 Oct 2024 13:31:03 -0400
Subject: [PATCH 1/2] feat: Add model_utils and model_constants

- Enhanced the initialization logic in model_constants to handle delayed
  imports and circular dependencies.
- Improved type hinting for better code clarity and maintainability.

Details:

- `get_model_info`: Retrieves information about all available model
  components; its output populates the `MODEL_INFO` dictionary.
- `MODEL_INFO`: A dictionary keyed by model component name; each value holds
  details about the model, such as its `display_name` and input definitions.
- `PROVIDER_NAMES`: A list derived from `MODEL_INFO` that holds the display
  names of all available model providers.
---
 .../base/langflow/base/models/model.py        | 61 +++++++++++++++++++
 .../langflow/base/models/model_constants.py   | 17 ++++++
 .../base/langflow/base/models/model_utils.py  | 31 ++++++++++
 .../tests/unit/base/models/__init__.py        |  0
 .../unit/base/models/test_model_constants.py  | 25 ++++++++
 5 files changed, 134 insertions(+)
 create mode 100644 src/backend/base/langflow/base/models/model_constants.py
 create mode 100644 src/backend/base/langflow/base/models/model_utils.py
 create mode 100644 src/backend/tests/unit/base/models/__init__.py
 create mode 100644 src/backend/tests/unit/base/models/test_model_constants.py
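Note: the `MODEL_INFO` mapping described above is keyed by model component
name. A rough, illustrative sketch of its shape follows; the component and
input names here are placeholders, and the real "inputs" entries are langflow
input objects produced by get_model_info(), not plain strings:

    # Illustrative only; real values come from get_model_info() at runtime.
    MODEL_INFO_EXAMPLE = {
        "OpenAIModelComponent": {
            "display_name": "OpenAI",
            "inputs": ["model_name", "temperature", "api_key"],  # placeholder input names
            "icon": "OpenAI",
        },
        "AnthropicModelComponent": {
            "display_name": "Anthropic",
            "inputs": ["model", "max_tokens"],  # placeholder input names
            "icon": "Anthropic",
        },
    }

    # PROVIDER_NAMES is derived from the display names, in insertion order.
    PROVIDER_NAMES_EXAMPLE = [info["display_name"] for info in MODEL_INFO_EXAMPLE.values()]
    assert PROVIDER_NAMES_EXAMPLE == ["OpenAI", "Anthropic"]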
diff --git a/src/backend/base/langflow/base/models/model.py b/src/backend/base/langflow/base/models/model.py
index f1ed9885def..d0a25fb3592 100644
--- a/src/backend/base/langflow/base/models/model.py
+++ b/src/backend/base/langflow/base/models/model.py
@@ -1,3 +1,4 @@
+import importlib
 import json
 import warnings
 from abc import abstractmethod
@@ -206,3 +207,63 @@ def get_chat_result(
     @abstractmethod
     def build_model(self) -> LanguageModel:  # type: ignore[type-var]
         """Implement this method to build the model."""
+
+    def get_llm(self, provider_name: str, model_info: dict[str, dict[str, str | list[InputTypes]]]) -> LanguageModel:
+        """Get LLM model based on provider name and model info.
+
+        Args:
+            provider_name: Name of the model provider (e.g., "OpenAI", "Azure OpenAI")
+            model_info: Dictionary of model component information, keyed by
+                component name (display name and input definitions)
+
+        Returns:
+            Built LLM model instance
+        """
+        try:
+            if provider_name not in [model.get("display_name") for model in model_info.values()]:
+                msg = f"Unknown model provider: {provider_name}"
+                raise ValueError(msg)
+
+            # Find the component class name from MODEL_INFO in a single iteration
+            component_info, module_name = next(
+                ((info, key) for key, info in model_info.items() if info.get("display_name") == provider_name),
+                (None, None),
+            )
+            if not component_info:
+                msg = f"Component information not found for {provider_name}"
+                raise ValueError(msg)
+            component_inputs = component_info.get("inputs", [])
+            # Get the component class from the models module
+            # Ensure component_inputs is a list of the expected types
+            if not isinstance(component_inputs, list):
+                component_inputs = []
+            models_module = importlib.import_module("langflow.components.models")
+            component_class = getattr(models_module, str(module_name))
+            component = component_class()
+
+            return self.build_llm_model_from_inputs(component, component_inputs)
+        except Exception as e:
+            msg = f"Error building {provider_name} language model"
+            raise ValueError(msg) from e
+
+    def build_llm_model_from_inputs(
+        self, component: Component, inputs: list[InputTypes], prefix: str = ""
+    ) -> LanguageModel:
+        """Build LLM model from component and inputs.
+
+        Args:
+            component: LLM component instance
+            inputs: List of input definitions for the model
+            prefix: Prefix for the input names
+        Returns:
+            Built LLM model instance
+        """
+        # Ensure prefix is a string
+        prefix = prefix or ""
+        # Filter inputs to only include valid component input names
+        input_data = {
+            str(component_input.name): getattr(self, f"{prefix}{component_input.name}", None)
+            for component_input in inputs
+        }
+
+        return component.set(**input_data).build_model()
diff --git a/src/backend/base/langflow/base/models/model_constants.py b/src/backend/base/langflow/base/models/model_constants.py
new file mode 100644
index 00000000000..aa1f05766a2
--- /dev/null
+++ b/src/backend/base/langflow/base/models/model_constants.py
@@ -0,0 +1,17 @@
+class ModelConstants:
+    """Class to hold model-related constants. Used to solve the circular import issue."""
+
+    PROVIDER_NAMES: list[str] = []
+    MODEL_INFO: dict[str, dict[str, str | list]] = {}  # Adjusted type hint
+
+    @staticmethod
+    def initialize():
+        from langflow.base.models.model_utils import get_model_info  # Delayed import
+
+        model_info = get_model_info()
+        ModelConstants.MODEL_INFO = model_info
+        ModelConstants.PROVIDER_NAMES = [
+            str(model.get("display_name"))
+            for model in model_info.values()
+            if isinstance(model.get("display_name"), str)
+        ]
diff --git a/src/backend/base/langflow/base/models/model_utils.py b/src/backend/base/langflow/base/models/model_utils.py
new file mode 100644
index 00000000000..9157d873c1b
--- /dev/null
+++ b/src/backend/base/langflow/base/models/model_utils.py
@@ -0,0 +1,31 @@
+import importlib
+
+from langflow.base.models.model import LCModelComponent
+from langflow.inputs.inputs import InputTypes
+
+
+def get_model_info() -> dict[str, dict[str, str | list[InputTypes]]]:
+    """Get inputs for all model components."""
+    model_inputs = {}
+    models_module = importlib.import_module("langflow.components.models")
+    model_component_names = getattr(models_module, "__all__", [])
+
+    for name in model_component_names:
+        if name in ("base", "DynamicLLMComponent"):  # Skip the base module
+            continue
+
+        component_class = getattr(models_module, name)
+        if issubclass(component_class, LCModelComponent):
+            component = component_class()
+            base_input_names = {input_field.name for input_field in LCModelComponent._base_inputs}
+            input_fields_list = [
+                input_field for input_field in component.inputs if input_field.name not in base_input_names
+            ]
+            component_display_name = component.display_name
+            model_inputs[name] = {
+                "display_name": component_display_name,
+                "inputs": input_fields_list,
+                "icon": component.icon,
+            }
+
+    return model_inputs
diff --git a/src/backend/tests/unit/base/models/__init__.py b/src/backend/tests/unit/base/models/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/backend/tests/unit/base/models/test_model_constants.py b/src/backend/tests/unit/base/models/test_model_constants.py
new file mode 100644
index 00000000000..27d0f4704d0
--- /dev/null
+++ b/src/backend/tests/unit/base/models/test_model_constants.py
@@ -0,0 +1,25 @@
+from src.backend.base.langflow.base.models.model_constants import ModelConstants
+
+
+def test_provider_names():
+    # Initialize the ModelConstants
+    ModelConstants.initialize()
+
+    # Expected provider names
+    expected_provider_names = [
+        "AIML",
+        "Amazon Bedrock",
+        "Anthropic",
+        "Azure OpenAI",
+        "Ollama",
+        "Vertex AI",
+        "Cohere",
+        "Google Generative AI",
+        "HuggingFace",
+        "OpenAI",
+        "Perplexity",
+        "Qianfan",
+    ]
+
+    # Assert that the provider names match the expected list
+    assert expected_provider_names == ModelConstants.PROVIDER_NAMES
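Note: a minimal sketch of how the pieces in this patch fit together, assuming
an environment where langflow with this patch is importable (the `get_llm`
call shown in the comment must run inside a component that subclasses
LCModelComponent):

    from langflow.base.models.model_constants import ModelConstants

    ModelConstants.initialize()           # populates MODEL_INFO and PROVIDER_NAMES
    print(ModelConstants.PROVIDER_NAMES)  # e.g. ["AIML", "Amazon Bedrock", ...]

    # Inside a subclass of LCModelComponent, a provider's model could then be
    # built from the shared metadata:
    #     llm = self.get_llm("OpenAI", ModelConstants.MODEL_INFO)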
From 308075bcd78a96c884e61dcfed83a412f3305b8e Mon Sep 17 00:00:00 2001
From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com>
Date: Fri, 1 Nov 2024 15:22:11 +0000
Subject: [PATCH 2/2] [autofix.ci] apply automated fixes

---
 src/backend/base/langflow/worker.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/backend/base/langflow/worker.py b/src/backend/base/langflow/worker.py
index 9d285741794..e57c99103a7 100644
--- a/src/backend/base/langflow/worker.py
+++ b/src/backend/base/langflow/worker.py
@@ -29,10 +29,10 @@ def build_vertex(self, vertex: Vertex) -> Vertex:
 
 @celery_app.task(acks_late=True)
 def process_graph_cached_task(
-    data_graph: dict[str, Any],  # noqa: ARG001
-    inputs: dict | list[dict] | None = None,  # noqa: ARG001
-    clear_cache=False,  # noqa: ARG001, FBT002
-    session_id=None,  # noqa: ARG001
+    data_graph: dict[str, Any],
+    inputs: dict | list[dict] | None = None,
+    clear_cache=False,  # noqa: FBT002
+    session_id=None,
 ) -> dict[str, Any]:
     msg = "This task is not implemented yet"
     raise NotImplementedError(msg)