From 849c7ac6e3185c75c43cf28005244b08d9adced3 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Tue, 15 Oct 2024 12:02:37 +0200 Subject: [PATCH] Auto-fix ruff rules in tests --- scripts/ci/pypi_nightly_tag.py | 13 +- scripts/ci/update_lf_base_dependency.py | 23 ++-- scripts/ci/update_pyproject_name.py | 25 ++-- scripts/ci/update_pyproject_version.py | 26 ++-- scripts/ci/update_uv_dependency.py | 15 ++- scripts/factory_restart_space.py | 5 +- src/backend/langflow/version/version.py | 16 +-- src/backend/tests/api_keys.py | 6 +- src/backend/tests/conftest.py | 58 ++++---- .../components/astra/test_astra_component.py | 20 ++- .../helpers/test_parse_json_data.py | 3 +- .../components/inputs/test_chat_input.py | 5 +- .../components/inputs/test_text_input.py | 6 +- .../integration/components/mock_components.py | 5 +- .../output_parsers/test_output_parser.py | 4 +- .../components/outputs/test_chat_output.py | 3 +- .../components/outputs/test_text_output.py | 3 +- .../components/prompts/test_prompt.py | 4 +- .../integration/flows/test_basic_prompting.py | 1 - src/backend/tests/integration/test_misc.py | 6 +- src/backend/tests/integration/utils.py | 57 ++++---- src/backend/tests/locust/locustfile.py | 14 +- src/backend/tests/unit/api/test_api_utils.py | 3 +- .../tests/unit/api/v1/test_variable.py | 43 +++--- src/backend/tests/unit/base/load/test_load.py | 1 + .../unit/base/tools/test_component_toolkit.py | 3 +- .../models/test_ChatOllama_component.py | 1 - .../prompts/test_prompt_component.py | 1 - .../components/tools/test_python_repl_tool.py | 1 - .../components/tools/test_yfinance_tool.py | 1 - .../component/test_component_to_tool.py | 1 - .../custom/custom_component/test_component.py | 1 - .../tests/unit/events/test_event_manager.py | 1 - src/backend/tests/unit/exceptions/test_api.py | 3 +- .../tests/unit/graph/edge/test_edge_base.py | 1 - .../graph/graph/state/test_state_model.py | 3 +- .../tests/unit/graph/graph/test_base.py | 3 +- .../unit/graph/graph/test_callback_graph.py | 1 - .../tests/unit/graph/graph/test_cycles.py | 3 - .../graph/graph/test_graph_state_model.py | 7 +- .../graph/test_runnable_vertices_manager.py | 12 +- .../tests/unit/graph/graph/test_utils.py | 1 - src/backend/tests/unit/graph/test_graph.py | 7 +- .../starter_projects/test_memory_chatbot.py | 9 +- .../starter_projects/test_vector_store_rag.py | 20 ++- src/backend/tests/unit/inputs/test_inputs.py | 3 +- src/backend/tests/unit/io/test_io_schema.py | 7 +- .../tests/unit/io/test_table_schema.py | 1 - .../tests/unit/schema/test_schema_message.py | 1 - .../unit/services/variable/test_service.py | 3 +- src/backend/tests/unit/test_api_key.py | 4 +- src/backend/tests/unit/test_chat_endpoint.py | 9 +- src/backend/tests/unit/test_cli.py | 1 - .../tests/unit/test_custom_component.py | 124 +++++------------- .../unit/test_custom_component_with_client.py | 3 +- src/backend/tests/unit/test_data_class.py | 1 - .../tests/unit/test_data_components.py | 9 +- src/backend/tests/unit/test_database.py | 7 +- src/backend/tests/unit/test_endpoints.py | 27 ++-- .../unit/test_experimental_components.py | 2 +- src/backend/tests/unit/test_files.py | 5 +- src/backend/tests/unit/test_frontend_nodes.py | 1 - .../tests/unit/test_helper_components.py | 4 +- src/backend/tests/unit/test_initial_setup.py | 8 +- .../tests/unit/test_kubernetes_secrets.py | 6 +- src/backend/tests/unit/test_loading.py | 6 +- src/backend/tests/unit/test_logger.py | 13 +- src/backend/tests/unit/test_login.py | 3 +- src/backend/tests/unit/test_messages.py | 19 +-- .../tests/unit/test_messages_endpoints.py | 12 +- src/backend/tests/unit/test_process.py | 1 - src/backend/tests/unit/test_schema.py | 17 ++- src/backend/tests/unit/test_telemetry.py | 8 +- src/backend/tests/unit/test_template.py | 6 +- src/backend/tests/unit/test_user.py | 5 +- src/backend/tests/unit/test_validate_code.py | 8 +- src/backend/tests/unit/test_version.py | 12 +- .../utils/test_connection_string_parser.py | 2 +- .../unit/utils/test_format_directory_path.py | 2 +- .../unit/utils/test_rewrite_file_path.py | 4 +- .../unit/utils/test_truncate_long_strings.py | 6 +- .../test_truncate_long_strings_on_objects.py | 6 +- 82 files changed, 351 insertions(+), 450 deletions(-) diff --git a/scripts/ci/pypi_nightly_tag.py b/scripts/ci/pypi_nightly_tag.py index f117427366b6..4ac02fc2cff0 100755 --- a/scripts/ci/pypi_nightly_tag.py +++ b/scripts/ci/pypi_nightly_tag.py @@ -1,7 +1,5 @@ #!/usr/bin/env python -""" -Idea from https://github.com/streamlit/streamlit/blob/4841cf91f1c820a392441092390c4c04907f9944/scripts/pypi_nightly_create_tag.py -""" +"""Idea from https://github.com/streamlit/streamlit/blob/4841cf91f1c820a392441092390c4c04907f9944/scripts/pypi_nightly_create_tag.py.""" import sys @@ -24,13 +22,15 @@ def get_latest_published_version(build_type: str, is_nightly: bool) -> Version: elif build_type == "main": url = PYPI_LANGFLOW_NIGHTLY_URL if is_nightly else PYPI_LANGFLOW_URL else: - raise ValueError(f"Invalid build type: {build_type}") + msg = f"Invalid build type: {build_type}" + raise ValueError(msg) res = requests.get(url) try: version_str = res.json()["info"]["version"] except Exception as e: - raise RuntimeError("Got unexpected response from PyPI", e) + msg = "Got unexpected response from PyPI" + raise RuntimeError(msg, e) return Version(version_str) @@ -75,7 +75,8 @@ def create_tag(build_type: str): if __name__ == "__main__": if len(sys.argv) != 2: - raise Exception("Specify base or main") + msg = "Specify base or main" + raise Exception(msg) build_type = sys.argv[1] tag = create_tag(build_type) diff --git a/scripts/ci/update_lf_base_dependency.py b/scripts/ci/update_lf_base_dependency.py index adf4d1a338c4..ed8269f3d8df 100755 --- a/scripts/ci/update_lf_base_dependency.py +++ b/scripts/ci/update_lf_base_dependency.py @@ -1,6 +1,6 @@ import os -import sys import re +import sys import packaging.version @@ -10,7 +10,7 @@ def update_base_dep(pyproject_path: str, new_version: str) -> None: """Update the langflow-base dependency in pyproject.toml.""" filepath = os.path.join(BASE_DIR, pyproject_path) - with open(filepath, "r") as file: + with open(filepath, encoding="utf-8") as file: content = file.read() replacement = f'langflow-base-nightly = "{new_version}"' @@ -18,33 +18,32 @@ def update_base_dep(pyproject_path: str, new_version: str) -> None: # Updates the pattern for poetry pattern = re.compile(r'langflow-base = \{ path = "\./src/backend/base", develop = true \}') if not pattern.search(content): - raise Exception(f'langflow-base poetry dependency not found in "{filepath}"') + msg = f'langflow-base poetry dependency not found in "{filepath}"' + raise Exception(msg) content = pattern.sub(replacement, content) - with open(filepath, "w") as file: + with open(filepath, "w", encoding="utf-8") as file: file.write(content) def verify_pep440(version): - """ - Verify if version is PEP440 compliant. + """Verify if version is PEP440 compliant. https://github.com/pypa/packaging/blob/16.7/packaging/version.py#L191 """ - try: return packaging.version.Version(version) - except packaging.version.InvalidVersion as e: - raise e + except packaging.version.InvalidVersion: + raise def main() -> None: if len(sys.argv) != 2: - raise Exception("New version not specified") + msg = "New version not specified" + raise Exception(msg) base_version = sys.argv[1] # Strip "v" prefix from version if present - if base_version.startswith("v"): - base_version = base_version[1:] + base_version = base_version.removeprefix("v") verify_pep440(base_version) update_base_dep("pyproject.toml", base_version) diff --git a/scripts/ci/update_pyproject_name.py b/scripts/ci/update_pyproject_name.py index e846ec87d234..86e9b3e89f06 100755 --- a/scripts/ci/update_pyproject_name.py +++ b/scripts/ci/update_pyproject_name.py @@ -1,6 +1,6 @@ import os -import sys import re +import sys BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) @@ -8,24 +8,25 @@ def update_pyproject_name(pyproject_path: str, new_project_name: str) -> None: """Update the project name in pyproject.toml.""" filepath = os.path.join(BASE_DIR, pyproject_path) - with open(filepath, "r") as file: + with open(filepath, encoding="utf-8") as file: content = file.read() # Regex to match the version line under [tool.poetry] pattern = re.compile(r'(?<=^name = ")[^"]+(?=")', re.MULTILINE) if not pattern.search(content): - raise Exception(f'Project name not found in "{filepath}"') + msg = f'Project name not found in "{filepath}"' + raise Exception(msg) content = pattern.sub(new_project_name, content) - with open(filepath, "w") as file: + with open(filepath, "w", encoding="utf-8") as file: file.write(content) def update_uv_dep(pyproject_path: str, new_project_name: str) -> None: """Update the langflow-base dependency in pyproject.toml.""" filepath = os.path.join(BASE_DIR, pyproject_path) - with open(filepath, "r") as file: + with open(filepath, encoding="utf-8") as file: content = file.read() if new_project_name == "langflow-nightly": @@ -35,19 +36,22 @@ def update_uv_dep(pyproject_path: str, new_project_name: str) -> None: pattern = re.compile(r"langflow-base = \{ workspace = true \}") replacement = "langflow-base-nightly = { workspace = true }" else: - raise ValueError(f"Invalid project name: {new_project_name}") + msg = f"Invalid project name: {new_project_name}" + raise ValueError(msg) # Updates the dependency name for uv if not pattern.search(content): - raise Exception(f"{replacement} uv dependency not found in {filepath}") + msg = f"{replacement} uv dependency not found in {filepath}" + raise Exception(msg) content = pattern.sub(replacement, content) - with open(filepath, "w") as file: + with open(filepath, "w", encoding="utf-8") as file: file.write(content) def main() -> None: if len(sys.argv) != 3: - raise Exception("Must specify project name and build type, e.g. langflow-nightly base") + msg = "Must specify project name and build type, e.g. langflow-nightly base" + raise Exception(msg) new_project_name = sys.argv[1] build_type = sys.argv[2] @@ -58,7 +62,8 @@ def main() -> None: update_pyproject_name("pyproject.toml", new_project_name) update_uv_dep("pyproject.toml", new_project_name) else: - raise ValueError(f"Invalid build type: {build_type}") + msg = f"Invalid build type: {build_type}" + raise ValueError(msg) if __name__ == "__main__": diff --git a/scripts/ci/update_pyproject_version.py b/scripts/ci/update_pyproject_version.py index 93c4511d1728..c65690e6cbf5 100755 --- a/scripts/ci/update_pyproject_version.py +++ b/scripts/ci/update_pyproject_version.py @@ -1,6 +1,6 @@ import os -import sys import re +import sys import packaging.version @@ -10,42 +10,41 @@ def update_pyproject_version(pyproject_path: str, new_version: str) -> None: """Update the version in pyproject.toml.""" filepath = os.path.join(BASE_DIR, pyproject_path) - with open(filepath, "r") as file: + with open(filepath, encoding="utf-8") as file: content = file.read() # Regex to match the version line under [tool.poetry] pattern = re.compile(r'(?<=^version = ")[^"]+(?=")', re.MULTILINE) if not pattern.search(content): - raise Exception(f'Project version not found in "{filepath}"') + msg = f'Project version not found in "{filepath}"' + raise Exception(msg) content = pattern.sub(new_version, content) - with open(filepath, "w") as file: + with open(filepath, "w", encoding="utf-8") as file: file.write(content) def verify_pep440(version): - """ - Verify if version is PEP440 compliant. + """Verify if version is PEP440 compliant. https://github.com/pypa/packaging/blob/16.7/packaging/version.py#L191 """ - try: return packaging.version.Version(version) - except packaging.version.InvalidVersion as e: - raise e + except packaging.version.InvalidVersion: + raise def main() -> None: if len(sys.argv) != 3: - raise Exception("New version not specified") + msg = "New version not specified" + raise Exception(msg) new_version = sys.argv[1] # Strip "v" prefix from version if present - if new_version.startswith("v"): - new_version = new_version[1:] + new_version = new_version.removeprefix("v") build_type = sys.argv[2] @@ -56,7 +55,8 @@ def main() -> None: elif build_type == "main": update_pyproject_version("pyproject.toml", new_version) else: - raise ValueError(f"Invalid build type: {build_type}") + msg = f"Invalid build type: {build_type}" + raise ValueError(msg) if __name__ == "__main__": diff --git a/scripts/ci/update_uv_dependency.py b/scripts/ci/update_uv_dependency.py index 6f3032a526c6..41fce4524a07 100755 --- a/scripts/ci/update_uv_dependency.py +++ b/scripts/ci/update_uv_dependency.py @@ -1,38 +1,39 @@ import os -import sys import re +import sys BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) def update_uv_dep(base_version: str) -> None: """Update the langflow-base dependency in pyproject.toml.""" - pyproject_path = os.path.join(BASE_DIR, "pyproject.toml") # Read the pyproject.toml file content - with open(pyproject_path, "r") as file: + with open(pyproject_path, encoding="utf-8") as file: content = file.read() # For the main project, update the langflow-base dependency in the UV section pattern = re.compile(r'(dependencies\s*=\s*\[\s*\n\s*)("langflow-base==[\d.]+")') - replacement = r'\1"langflow-base-nightly=={}"'.format(base_version) + replacement = rf'\1"langflow-base-nightly=={base_version}"' # Check if the pattern is found if not pattern.search(content): - raise Exception(f"{pattern} UV dependency not found in {pyproject_path}") + msg = f"{pattern} UV dependency not found in {pyproject_path}" + raise Exception(msg) # Replace the matched pattern with the new one content = pattern.sub(replacement, content) # Write the updated content back to the file - with open(pyproject_path, "w") as file: + with open(pyproject_path, "w", encoding="utf-8") as file: file.write(content) def main() -> None: if len(sys.argv) != 2: - raise Exception("specify base version") + msg = "specify base version" + raise Exception(msg) base_version = sys.argv[1] base_version = base_version.lstrip("v") update_uv_dep(base_version) diff --git a/scripts/factory_restart_space.py b/scripts/factory_restart_space.py index 50cf86163e26..9006d5fa7815 100644 --- a/scripts/factory_restart_space.py +++ b/scripts/factory_restart_space.py @@ -6,6 +6,7 @@ # ] # /// import argparse +import sys from huggingface_hub import HfApi, list_models from rich import print @@ -23,11 +24,11 @@ if not space: print("Please provide a space to restart.") - exit() + sys.exit() if not parsed_args.token: print("Please provide an API token.") - exit() + sys.exit() # Or configure a HfApi client hf_api = HfApi( diff --git a/src/backend/langflow/version/version.py b/src/backend/langflow/version/version.py index 40619003c8a3..2cd42e3977f1 100644 --- a/src/backend/langflow/version/version.py +++ b/src/backend/langflow/version/version.py @@ -1,6 +1,8 @@ +import contextlib + + def get_version() -> str: - """ - Retrieves the version of the package from a possible list of package names. + """Retrieves the version of the package from a possible list of package names. This accounts for after package names are updated for -nightly builds. Returns: @@ -19,20 +21,18 @@ def get_version() -> str: ] _version = None for pkg_name in pkg_names: - try: + with contextlib.suppress(ImportError, metadata.PackageNotFoundError): _version = metadata.version(pkg_name) - except (ImportError, metadata.PackageNotFoundError): - pass if _version is None: - raise ValueError(f"Package not found from options {pkg_names}") + msg = f"Package not found from options {pkg_names}" + raise ValueError(msg) return _version def is_pre_release(v: str) -> bool: - """ - Returns a boolean indicating whether the version is a pre-release version, + """Returns a boolean indicating whether the version is a pre-release version, as per the definition of a pre-release segment from PEP 440. """ return any(label in v for label in ["a", "b", "rc"]) diff --git a/src/backend/tests/api_keys.py b/src/backend/tests/api_keys.py index 6dac302a13eb..42260f744a51 100644 --- a/src/backend/tests/api_keys.py +++ b/src/backend/tests/api_keys.py @@ -4,8 +4,7 @@ def get_required_env_var(var: str) -> str: - """ - Get the value of the specified environment variable. + """Get the value of the specified environment variable. Args: var (str): The environment variable to get. @@ -18,7 +17,8 @@ def get_required_env_var(var: str) -> str: """ value = os.getenv(var) if not value: - raise ValueError(f"Environment variable {var} is not set") + msg = f"Environment variable {var} is not set" + raise ValueError(msg) return value diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index c776064e54fd..dd0d7907bfee 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -17,13 +17,6 @@ from dotenv import load_dotenv from fastapi.testclient import TestClient from httpx import ASGITransport, AsyncClient -from loguru import logger -from pytest import LogCaptureFixture -from sqlmodel import Session, SQLModel, create_engine, select -from sqlmodel.pool import StaticPool -from tests.api_keys import get_openai_api_key -from typer.testing import CliRunner - from langflow.graph.graph.base import Graph from langflow.initial_setup.setup import STARTER_FOLDER_NAME from langflow.services.auth.utils import get_password_hash @@ -35,6 +28,13 @@ from langflow.services.database.models.vertex_builds.crud import delete_vertex_builds_by_flow_id from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service +from loguru import logger +from pytest import LogCaptureFixture +from sqlmodel import Session, SQLModel, create_engine, select +from sqlmodel.pool import StaticPool +from typer.testing import CliRunner + +from tests.api_keys import get_openai_api_key if TYPE_CHECKING: from langflow.services.database.service import DatabaseService @@ -115,7 +115,7 @@ def caplog(caplog: LogCaptureFixture): logger.remove(handler_id) -@pytest.fixture() +@pytest.fixture async def async_client() -> AsyncGenerator: from langflow.main import create_app @@ -189,8 +189,7 @@ def distributed_client_fixture(session: Session, monkeypatch, distributed_env): def get_graph(_type="basic"): - """Get a graph from a json file""" - + """Get a graph from a json file.""" if _type == "basic": path = pytest.BASIC_EXAMPLE_PATH elif _type == "complex": @@ -198,7 +197,7 @@ def get_graph(_type="basic"): elif _type == "openapi": path = pytest.OPENAPI_EXAMPLE_PATH - with open(path) as f: + with open(path, encoding="utf-8") as f: flow_graph = json.load(f) data_graph = flow_graph["data"] nodes = data_graph["nodes"] @@ -210,13 +209,13 @@ def get_graph(_type="basic"): @pytest.fixture def basic_graph_data(): - with open(pytest.BASIC_EXAMPLE_PATH) as f: + with open(pytest.BASIC_EXAMPLE_PATH, encoding="utf-8") as f: return json.load(f) @pytest.fixture def basic_graph(): - yield get_graph() + return get_graph() @pytest.fixture @@ -231,55 +230,55 @@ def openapi_graph(): @pytest.fixture def json_flow(): - with open(pytest.BASIC_EXAMPLE_PATH) as f: + with open(pytest.BASIC_EXAMPLE_PATH, encoding="utf-8") as f: return f.read() @pytest.fixture def grouped_chat_json_flow(): - with open(pytest.GROUPED_CHAT_EXAMPLE_PATH) as f: + with open(pytest.GROUPED_CHAT_EXAMPLE_PATH, encoding="utf-8") as f: return f.read() @pytest.fixture def one_grouped_chat_json_flow(): - with open(pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH) as f: + with open(pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH, encoding="utf-8") as f: return f.read() @pytest.fixture def vector_store_grouped_json_flow(): - with open(pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH) as f: + with open(pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH, encoding="utf-8") as f: return f.read() @pytest.fixture def json_flow_with_prompt_and_history(): - with open(pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY) as f: + with open(pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY, encoding="utf-8") as f: return f.read() @pytest.fixture def json_simple_api_test(): - with open(pytest.SIMPLE_API_TEST) as f: + with open(pytest.SIMPLE_API_TEST, encoding="utf-8") as f: return f.read() @pytest.fixture def json_vector_store(): - with open(pytest.VECTOR_STORE_PATH) as f: + with open(pytest.VECTOR_STORE_PATH, encoding="utf-8") as f: return f.read() @pytest.fixture def json_webhook_test(): - with open(pytest.WEBHOOK_TEST) as f: + with open(pytest.WEBHOOK_TEST, encoding="utf-8") as f: return f.read() @pytest.fixture def json_memory_chatbot_no_llm(): - with open(pytest.MEMORY_CHATBOT_NO_LLM) as f: + with open(pytest.MEMORY_CHATBOT_NO_LLM, encoding="utf-8") as f: return f.read() @@ -325,12 +324,12 @@ def blank_session_getter(db_service: "DatabaseService"): with Session(db_service.engine) as session: yield session - yield blank_session_getter + return blank_session_getter @pytest.fixture def runner(): - yield CliRunner() + return CliRunner() @pytest.fixture @@ -347,7 +346,7 @@ async def test_user(client): await client.delete(f"/api/v1/users/{user['id']}") -@pytest.fixture(scope="function") +@pytest.fixture def active_user(client): db_manager = get_db_service() with db_manager.with_session() as session: @@ -382,7 +381,7 @@ async def logged_in_headers(client, active_user): assert response.status_code == 200 tokens = response.json() a_token = tokens["access_token"] - yield {"Authorization": f"Bearer {a_token}"} + return {"Authorization": f"Bearer {a_token}"} @pytest.fixture @@ -405,13 +404,13 @@ def flow(client, json_flow: str, active_user): @pytest.fixture def json_chat_input(): - with open(pytest.CHAT_INPUT) as f: + with open(pytest.CHAT_INPUT, encoding="utf-8") as f: yield f.read() @pytest.fixture def json_two_outputs(): - with open(pytest.TWO_OUTPUTS) as f: + with open(pytest.TWO_OUTPUTS, encoding="utf-8") as f: yield f.read() @@ -540,7 +539,8 @@ def get_starter_project(active_user): .where(Flow.name == "Basic Prompting (Hello, World)") ).first() if not flow: - raise ValueError("No starter project found") + msg = "No starter project found" + raise ValueError(msg) # ensure openai api key is set get_openai_api_key() diff --git a/src/backend/tests/integration/components/astra/test_astra_component.py b/src/backend/tests/integration/components/astra/test_astra_component.py index 57fa80760e93..06da3ec54cd8 100644 --- a/src/backend/tests/integration/components/astra/test_astra_component.py +++ b/src/backend/tests/integration/components/astra/test_astra_component.py @@ -1,18 +1,14 @@ import os -from astrapy.db import AstraDB import pytest - +from astrapy.db import AstraDB +from langchain_core.documents import Document from langflow.components.embeddings import OpenAIEmbeddingsComponent from langflow.components.vectorstores import AstraVectorStoreComponent -from tests.api_keys import get_astradb_application_token, get_astradb_api_endpoint, get_openai_api_key -from tests.integration.components.mock_components import TextToData -from tests.integration.utils import ComponentInputHandle -from langchain_core.documents import Document - - from langflow.schema.data import Data -from tests.integration.utils import run_single_component +from tests.api_keys import get_astradb_api_endpoint, get_astradb_application_token, get_openai_api_key +from tests.integration.components.mock_components import TextToData +from tests.integration.utils import ComponentInputHandle, run_single_component BASIC_COLLECTION = "test_basic" SEARCH_COLLECTION = "test_search" @@ -30,7 +26,7 @@ ] -@pytest.fixture() +@pytest.fixture def astradb_client(request): client = AstraDB(api_endpoint=get_astradb_api_endpoint(), token=get_astradb_application_token()) yield client @@ -139,7 +135,7 @@ def test_astra_vectorize(): @pytest.mark.api_key_required def test_astra_vectorize_with_provider_api_key(): - """tests vectorize using an openai api key""" + """Tests vectorize using an openai api key.""" from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions application_token = get_astradb_application_token() @@ -196,7 +192,7 @@ def test_astra_vectorize_with_provider_api_key(): @pytest.mark.api_key_required def test_astra_vectorize_passes_authentication(): - """tests vectorize using the authentication parameter""" + """Tests vectorize using the authentication parameter.""" from langchain_astradb import AstraDBVectorStore, CollectionVectorServiceOptions store = None diff --git a/src/backend/tests/integration/components/helpers/test_parse_json_data.py b/src/backend/tests/integration/components/helpers/test_parse_json_data.py index 9b3d13ebdcea..5c9931c97f96 100644 --- a/src/backend/tests/integration/components/helpers/test_parse_json_data.py +++ b/src/backend/tests/integration/components/helpers/test_parse_json_data.py @@ -1,10 +1,9 @@ import pytest - from langflow.components.helpers.ParseJSONData import ParseJSONDataComponent from langflow.components.inputs import ChatInput from langflow.schema import Data from tests.integration.components.mock_components import TextToData -from tests.integration.utils import run_single_component, ComponentInputHandle +from tests.integration.utils import ComponentInputHandle, run_single_component @pytest.mark.asyncio diff --git a/src/backend/tests/integration/components/inputs/test_chat_input.py b/src/backend/tests/integration/components/inputs/test_chat_input.py index 7a17b1e03d2a..f308685e6d8a 100644 --- a/src/backend/tests/integration/components/inputs/test_chat_input.py +++ b/src/backend/tests/integration/components/inputs/test_chat_input.py @@ -1,10 +1,9 @@ +import pytest +from langflow.components.inputs import ChatInput from langflow.memory import get_messages from langflow.schema.message import Message from tests.integration.utils import run_single_component -from langflow.components.inputs import ChatInput -import pytest - @pytest.mark.asyncio async def test_default(): diff --git a/src/backend/tests/integration/components/inputs/test_text_input.py b/src/backend/tests/integration/components/inputs/test_text_input.py index c0379d794208..c08d65067bd6 100644 --- a/src/backend/tests/integration/components/inputs/test_text_input.py +++ b/src/backend/tests/integration/components/inputs/test_text_input.py @@ -1,14 +1,12 @@ +import pytest +from langflow.components.inputs import TextInputComponent from langflow.schema.message import Message from tests.integration.utils import run_single_component -from langflow.components.inputs import TextInputComponent -import pytest - @pytest.mark.asyncio async def test_text_input(): outputs = await run_single_component(TextInputComponent, run_input="sample text", input_type="text") - print(outputs) assert isinstance(outputs["text"], Message) assert outputs["text"].text == "sample text" assert outputs["text"].sender is None diff --git a/src/backend/tests/integration/components/mock_components.py b/src/backend/tests/integration/components/mock_components.py index dc81594e56e3..2bf304304e45 100644 --- a/src/backend/tests/integration/components/mock_components.py +++ b/src/backend/tests/integration/components/mock_components.py @@ -1,8 +1,7 @@ import json -from typing import List from langflow.custom import Component -from langflow.inputs import StrInput, BoolInput +from langflow.inputs import BoolInput, StrInput from langflow.schema import Data from langflow.template import Output @@ -21,5 +20,5 @@ def _to_data(self, text: str) -> Data: return Data(data=json.loads(text)) return Data(text=text) - def create_data(self) -> List[Data]: + def create_data(self) -> list[Data]: return [self._to_data(t) for t in self.text_data] diff --git a/src/backend/tests/integration/components/output_parsers/test_output_parser.py b/src/backend/tests/integration/components/output_parsers/test_output_parser.py index c4668ef22696..f778b7a4ced4 100644 --- a/src/backend/tests/integration/components/output_parsers/test_output_parser.py +++ b/src/backend/tests/integration/components/output_parsers/test_output_parser.py @@ -1,6 +1,6 @@ import os -import pytest +import pytest from langflow.components.models.OpenAIModel import OpenAIModelComponent from langflow.components.output_parsers.OutputParser import OutputParserComponent from langflow.components.prompts.Prompt import PromptComponent @@ -23,7 +23,7 @@ async def test_csv_output_parser_openai(): prompt_handler = ComponentInputHandle( clazz=PromptComponent, inputs={ - "template": "List the first five positive integers.\n\n{format_instructions}", + "template": f"List the first five positive integers.\n\n{format_instructions}", "format_instructions": format_instructions, }, output_name="prompt", diff --git a/src/backend/tests/integration/components/outputs/test_chat_output.py b/src/backend/tests/integration/components/outputs/test_chat_output.py index d2b3ace550db..d5ca1de58508 100644 --- a/src/backend/tests/integration/components/outputs/test_chat_output.py +++ b/src/backend/tests/integration/components/outputs/test_chat_output.py @@ -1,10 +1,9 @@ +import pytest from langflow.components.outputs import ChatOutput from langflow.memory import get_messages from langflow.schema.message import Message from tests.integration.utils import run_single_component -import pytest - @pytest.mark.asyncio async def test_string(): diff --git a/src/backend/tests/integration/components/outputs/test_text_output.py b/src/backend/tests/integration/components/outputs/test_text_output.py index 5c0e2cdb424f..87e027d33552 100644 --- a/src/backend/tests/integration/components/outputs/test_text_output.py +++ b/src/backend/tests/integration/components/outputs/test_text_output.py @@ -1,9 +1,8 @@ +import pytest from langflow.components.outputs import TextOutputComponent from langflow.schema.message import Message from tests.integration.utils import run_single_component -import pytest - @pytest.mark.asyncio async def test(): diff --git a/src/backend/tests/integration/components/prompts/test_prompt.py b/src/backend/tests/integration/components/prompts/test_prompt.py index 31ae9aa81aa5..744653269a17 100644 --- a/src/backend/tests/integration/components/prompts/test_prompt.py +++ b/src/backend/tests/integration/components/prompts/test_prompt.py @@ -1,13 +1,11 @@ +import pytest from langflow.components.prompts import PromptComponent from langflow.schema.message import Message from tests.integration.utils import run_single_component -import pytest - @pytest.mark.asyncio async def test(): outputs = await run_single_component(PromptComponent, inputs={"template": "test {var1}", "var1": "from the var"}) - print(outputs) assert isinstance(outputs["prompt"], Message) assert outputs["prompt"].text == "test from the var" diff --git a/src/backend/tests/integration/flows/test_basic_prompting.py b/src/backend/tests/integration/flows/test_basic_prompting.py index ef0f773d9b03..47298099919b 100644 --- a/src/backend/tests/integration/flows/test_basic_prompting.py +++ b/src/backend/tests/integration/flows/test_basic_prompting.py @@ -1,5 +1,4 @@ import pytest - from langflow.components.inputs import ChatInput from langflow.components.outputs import ChatOutput from langflow.components.prompts import PromptComponent diff --git a/src/backend/tests/integration/test_misc.py b/src/backend/tests/integration/test_misc.py index e15cd26d8464..5d10fe1fdf5b 100644 --- a/src/backend/tests/integration/test_misc.py +++ b/src/backend/tests/integration/test_misc.py @@ -3,7 +3,6 @@ import pytest from fastapi import status from fastapi.testclient import TestClient - from langflow.graph.schema import RunOutputs from langflow.initial_setup.setup import load_starter_projects from langflow.load import run_flow_from_json @@ -80,9 +79,8 @@ async def test_run_with_inputs_and_outputs(client, starter_project, created_api_ @pytest.mark.noclient @pytest.mark.api_key_required def test_run_flow_from_json_object(): - """Test loading a flow from a json file and applying tweaks""" - _, projects = zip(*load_starter_projects()) - project = [project for project in projects if "Basic Prompting" in project["name"]][0] + """Test loading a flow from a json file and applying tweaks.""" + project = next(project for _, project in load_starter_projects() if "Basic Prompting" in project["name"]) results = run_flow_from_json(project, input_value="test", fallback_to_env_vars=True) assert results is not None assert all(isinstance(result, RunOutputs) for result in results) diff --git a/src/backend/tests/integration/utils.py b/src/backend/tests/integration/utils.py index 8203dd60459a..7bc88251f3ba 100644 --- a/src/backend/tests/integration/utils.py +++ b/src/backend/tests/integration/utils.py @@ -1,21 +1,19 @@ import dataclasses import os import uuid -from typing import Optional, Any +from typing import Any +import requests from astrapy.admin import parse_api_endpoint - from langflow.api.v1.schemas import InputValueRequest from langflow.custom import Component from langflow.field_typing import Embeddings from langflow.graph import Graph from langflow.processing.process import run_graph_internal -import requests def check_env_vars(*vars): - """ - Check if all specified environment variables are set. + """Check if all specified environment variables are set. Args: *vars (str): The environment variables to check. @@ -27,8 +25,7 @@ def check_env_vars(*vars): def valid_nvidia_vectorize_region(api_endpoint: str) -> bool: - """ - Check if the specified region is valid. + """Check if the specified region is valid. Args: region (str): The region to check. @@ -38,8 +35,9 @@ def valid_nvidia_vectorize_region(api_endpoint: str) -> bool: """ parsed_endpoint = parse_api_endpoint(api_endpoint) if not parsed_endpoint: - raise ValueError("Invalid ASTRA_DB_API_ENDPOINT") - return parsed_endpoint.region in ["us-east-2"] + msg = "Invalid ASTRA_DB_API_ENDPOINT" + raise ValueError(msg) + return parsed_endpoint.region == "us-east-2" class MockEmbeddings(Embeddings): @@ -70,15 +68,15 @@ def get_components_by_type(self, component_type): if node["data"]["type"] == component_type: result.append(node["id"]) if not result: - raise ValueError( - f"Component of type {component_type} not found, available types: {', '.join(set(node['data']['type'] for node in self.json['data']['nodes']))}" - ) + msg = f"Component of type {component_type} not found, available types: {', '.join({node['data']['type'] for node in self.json['data']['nodes']})}" + raise ValueError(msg) return result def get_component_by_type(self, component_type): components = self.get_components_by_type(component_type) if len(components) > 1: - raise ValueError(f"Multiple components of type {component_type} found") + msg = f"Multiple components of type {component_type} found" + raise ValueError(msg) return components[0] def set_value(self, component_id, key, value): @@ -86,13 +84,15 @@ def set_value(self, component_id, key, value): for node in self.json["data"]["nodes"]: if node["id"] == component_id: if key not in node["data"]["node"]["template"]: - raise ValueError(f"Component {component_id} does not have input {key}") + msg = f"Component {component_id} does not have input {key}" + raise ValueError(msg) node["data"]["node"]["template"][key]["value"] = value node["data"]["node"]["template"][key]["load_from_db"] = False done = True break if not done: - raise ValueError(f"Component {component_id} not found") + msg = f"Component {component_id} not found" + raise ValueError(msg) def download_flow_from_github(name: str, version: str) -> JSONFlow: @@ -105,18 +105,15 @@ def download_flow_from_github(name: str, version: str) -> JSONFlow: async def run_json_flow( - json_flow: JSONFlow, run_input: Optional[Any] = None, session_id: Optional[str] = None + json_flow: JSONFlow, run_input: Any | None = None, session_id: str | None = None ) -> dict[str, Any]: graph = Graph.from_payload(json_flow.json) return await run_flow(graph, run_input, session_id) -async def run_flow(graph: Graph, run_input: Optional[Any] = None, session_id: Optional[str] = None) -> dict[str, Any]: +async def run_flow(graph: Graph, run_input: Any | None = None, session_id: str | None = None) -> dict[str, Any]: graph.prepare() - if run_input: - graph_run_inputs = [InputValueRequest(input_value=run_input, type="chat")] - else: - graph_run_inputs = [] + graph_run_inputs = [InputValueRequest(input_value=run_input, type="chat")] if run_input else [] flow_id = str(uuid.uuid4()) @@ -137,23 +134,24 @@ class ComponentInputHandle: async def run_single_component( clazz: type, - inputs: dict = None, - run_input: Optional[Any] = None, - session_id: Optional[str] = None, - input_type: Optional[str] = "chat", + inputs: dict | None = None, + run_input: Any | None = None, + session_id: str | None = None, + input_type: str | None = "chat", ) -> dict[str, Any]: user_id = str(uuid.uuid4()) flow_id = str(uuid.uuid4()) graph = Graph(user_id=user_id, flow_id=flow_id) - def _add_component(clazz: type, inputs: Optional[dict] = None) -> str: + def _add_component(clazz: type, inputs: dict | None = None) -> str: raw_inputs = {} if inputs: for key, value in inputs.items(): if not isinstance(value, ComponentInputHandle): raw_inputs[key] = value if isinstance(value, Component): - raise ValueError("Component inputs must be wrapped in ComponentInputHandle") + msg = "Component inputs must be wrapped in ComponentInputHandle" + raise ValueError(msg) component = clazz(**raw_inputs, _user_id=user_id) component_id = graph.add_component(component) if inputs: @@ -165,10 +163,7 @@ def _add_component(clazz: type, inputs: Optional[dict] = None) -> str: component_id = _add_component(clazz, inputs) graph.prepare() - if run_input: - graph_run_inputs = [InputValueRequest(input_value=run_input, type=input_type)] - else: - graph_run_inputs = [] + graph_run_inputs = [InputValueRequest(input_value=run_input, type=input_type)] if run_input else [] _, _ = await run_graph_internal( graph, flow_id, session_id=session_id, inputs=graph_run_inputs, outputs=[component_id] diff --git a/src/backend/tests/locust/locustfile.py b/src/backend/tests/locust/locustfile.py index a48e813887cf..da0f87a6e4b8 100644 --- a/src/backend/tests/locust/locustfile.py +++ b/src/backend/tests/locust/locustfile.py @@ -11,8 +11,8 @@ class NameTest(FastHttpUser): wait_time = between(1, 5) - with open("names.txt", "r") as file: - names = [line.strip() for line in file.readlines()] + with open("names.txt", encoding="utf-8") as file: + names = [line.strip() for line in file] headers: dict = {} @@ -28,8 +28,9 @@ def poll_task(self, task_id, sleep_time=1): print(f"Poll Response: {response.js}") if status == "SUCCESS": return response.js.get("result") - elif status in ["FAILURE", "REVOKED"]: - raise ValueError(f"Task failed with status: {status}") + if status in {"FAILURE", "REVOKED"}: + msg = f"Task failed with status: {status}" + raise ValueError(msg) time.sleep(sleep_time) def process(self, name, flow_id, payload): @@ -45,7 +46,8 @@ def process(self, name, flow_id, payload): print(response.js) if response.status_code != 200: response.failure("Process call failed") - raise ValueError("Process call failed") + msg = "Process call failed" + raise ValueError(msg) task_id = response.js.get("id") session_id = response.js.get("session_id") assert task_id, "Inner Task ID not found" @@ -88,7 +90,7 @@ def on_start(self): print("Logged in") with open( Path(__file__).parent.parent / "data" / "BasicChatwithPromptandHistory.json", - "r", + encoding="utf-8", ) as f: json_flow = f.read() flow = orjson.loads(json_flow) diff --git a/src/backend/tests/unit/api/test_api_utils.py b/src/backend/tests/unit/api/test_api_utils.py index b21bc25d2b55..d992f10fe591 100644 --- a/src/backend/tests/unit/api/test_api_utils.py +++ b/src/backend/tests/unit/api/test_api_utils.py @@ -1,5 +1,6 @@ -from langflow.api.utils import get_suggestion_message from unittest.mock import patch + +from langflow.api.utils import get_suggestion_message from langflow.services.database.models.flow.utils import get_outdated_components from langflow.utils.version import get_version_info diff --git a/src/backend/tests/unit/api/v1/test_variable.py b/src/backend/tests/unit/api/v1/test_variable.py index 27a33e91ee1c..a8df7204d717 100644 --- a/src/backend/tests/unit/api/v1/test_variable.py +++ b/src/backend/tests/unit/api/v1/test_variable.py @@ -20,11 +20,11 @@ async def test_create_variable(client: AsyncClient, body, active_user, logged_in response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) result = response.json() - assert status.HTTP_201_CREATED == response.status_code + assert response.status_code == status.HTTP_201_CREATED assert body["name"] == result["name"] assert body["type"] == result["type"] assert body["default_fields"] == result["default_fields"] - assert "id" in result.keys() + assert "id" in result assert body["value"] != result["value"] @@ -34,7 +34,7 @@ async def test_create_variable__variable_name_already_exists(client: AsyncClient response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) result = response.json() - assert status.HTTP_400_BAD_REQUEST == response.status_code + assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Variable name already exists" in result["detail"] @@ -47,7 +47,7 @@ async def test_create_variable__variable_name_and_value_cannot_be_empty( response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) result = response.json() - assert status.HTTP_400_BAD_REQUEST == response.status_code + assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Variable name and value cannot be empty" in result["detail"] @@ -59,7 +59,7 @@ async def test_create_variable__variable_name_cannot_be_empty( response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) result = response.json() - assert status.HTTP_400_BAD_REQUEST == response.status_code + assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Variable name cannot be empty" in result["detail"] @@ -71,7 +71,7 @@ async def test_create_variable__variable_value_cannot_be_empty( response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) result = response.json() - assert status.HTTP_400_BAD_REQUEST == response.status_code + assert response.status_code == status.HTTP_400_BAD_REQUEST assert "Variable value cannot be empty" in result["detail"] @@ -84,7 +84,7 @@ async def test_create_variable__HTTPException(client: AsyncClient, body, active_ response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) result = response.json() - assert status.HTTP_418_IM_A_TEAPOT == response.status_code + assert response.status_code == status.HTTP_418_IM_A_TEAPOT assert generic_message in result["detail"] @@ -96,7 +96,7 @@ async def test_create_variable__Exception(client: AsyncClient, body, active_user response = await client.post("api/v1/variables/", json=body, headers=logged_in_headers) result = response.json() - assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert generic_message in result["detail"] @@ -109,7 +109,7 @@ async def test_read_variables(client: AsyncClient, body, active_user, logged_in_ response = await client.get("api/v1/variables/", headers=logged_in_headers) result = response.json() - assert status.HTTP_200_OK == response.status_code + assert response.status_code == status.HTTP_200_OK assert all(name in [r["name"] for r in result] for name in names) @@ -122,22 +122,21 @@ async def test_read_variables__empty(client: AsyncClient, active_user, logged_in response = await client.get("api/v1/variables/", headers=logged_in_headers) result = response.json() - assert status.HTTP_200_OK == response.status_code - assert [] == result + assert response.status_code == status.HTTP_200_OK + assert result == [] async def test_read_variables__(client: AsyncClient, active_user, logged_in_headers): generic_message = "Generic error message" - with pytest.raises(Exception) as exc: - with mock.patch("sqlmodel.Session.exec") as m: - m.side_effect = Exception(generic_message) + with pytest.raises(Exception) as exc, mock.patch("sqlmodel.Session.exec") as m: + m.side_effect = Exception(generic_message) - response = await client.get("api/v1/variables/", headers=logged_in_headers) - result = response.json() + response = await client.get("api/v1/variables/", headers=logged_in_headers) + result = response.json() - assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code - assert generic_message in result["detail"] + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert generic_message in result["detail"] assert generic_message in str(exc.value) @@ -154,7 +153,7 @@ async def test_update_variable(client: AsyncClient, body, active_user, logged_in response = await client.patch(f"api/v1/variables/{saved.get('id')}", json=body, headers=logged_in_headers) result = response.json() - assert status.HTTP_200_OK == response.status_code + assert response.status_code == status.HTTP_200_OK assert saved["id"] == result["id"] assert saved["name"] != result["name"] assert saved["default_fields"] != result["default_fields"] @@ -167,7 +166,7 @@ async def test_update_variable__Exception(client: AsyncClient, body, active_user response = await client.patch(f"api/v1/variables/{wrong_id}", json=body, headers=logged_in_headers) result = response.json() - assert status.HTTP_404_NOT_FOUND == response.status_code + assert response.status_code == status.HTTP_404_NOT_FOUND assert "Variable not found" in result["detail"] @@ -176,7 +175,7 @@ async def test_delete_variable(client: AsyncClient, body, active_user, logged_in saved = response.json() response = await client.delete(f"api/v1/variables/{saved.get('id')}", headers=logged_in_headers) - assert status.HTTP_204_NO_CONTENT == response.status_code + assert response.status_code == status.HTTP_204_NO_CONTENT async def test_delete_variable__Exception(client: AsyncClient, active_user, logged_in_headers): @@ -184,4 +183,4 @@ async def test_delete_variable__Exception(client: AsyncClient, active_user, logg response = await client.delete(f"api/v1/variables/{wrong_id}", headers=logged_in_headers) - assert status.HTTP_500_INTERNAL_SERVER_ERROR == response.status_code + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR diff --git a/src/backend/tests/unit/base/load/test_load.py b/src/backend/tests/unit/base/load/test_load.py index 964a326fa86c..f59ddfd1606c 100644 --- a/src/backend/tests/unit/base/load/test_load.py +++ b/src/backend/tests/unit/base/load/test_load.py @@ -1,4 +1,5 @@ import inspect + from langflow.load import run_flow_from_json diff --git a/src/backend/tests/unit/base/tools/test_component_toolkit.py b/src/backend/tests/unit/base/tools/test_component_toolkit.py index d471355b8b08..8fd047609a31 100644 --- a/src/backend/tests/unit/base/tools/test_component_toolkit.py +++ b/src/backend/tests/unit/base/tools/test_component_toolkit.py @@ -1,7 +1,6 @@ import os import pytest - from langflow.base.tools.component_tool import ComponentToolkit from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent from langflow.components.inputs.ChatInput import ChatInput @@ -81,7 +80,7 @@ def test_component_tool(): } assert component_toolkit.component == chat_input - result = component_tool.invoke(input=dict(input_value="test")) + result = component_tool.invoke(input={"input_value": "test"}) assert isinstance(result, Message) assert result.get_text() == "test" diff --git a/src/backend/tests/unit/components/models/test_ChatOllama_component.py b/src/backend/tests/unit/components/models/test_ChatOllama_component.py index d78c09a3f343..5dd88725df44 100644 --- a/src/backend/tests/unit/components/models/test_ChatOllama_component.py +++ b/src/backend/tests/unit/components/models/test_ChatOllama_component.py @@ -3,7 +3,6 @@ import pytest from langchain_community.chat_models.ollama import ChatOllama - from langflow.components.models.OllamaModel import ChatOllamaComponent diff --git a/src/backend/tests/unit/components/prompts/test_prompt_component.py b/src/backend/tests/unit/components/prompts/test_prompt_component.py index 8f8494013817..ff1a92569fa4 100644 --- a/src/backend/tests/unit/components/prompts/test_prompt_component.py +++ b/src/backend/tests/unit/components/prompts/test_prompt_component.py @@ -1,5 +1,4 @@ import pytest - from langflow.components.prompts.Prompt import PromptComponent # type: ignore diff --git a/src/backend/tests/unit/components/tools/test_python_repl_tool.py b/src/backend/tests/unit/components/tools/test_python_repl_tool.py index c5be4ba85c60..55a1f35ac1f4 100644 --- a/src/backend/tests/unit/components/tools/test_python_repl_tool.py +++ b/src/backend/tests/unit/components/tools/test_python_repl_tool.py @@ -1,5 +1,4 @@ import pytest - from langflow.components.tools.PythonREPLTool import PythonREPLToolComponent from langflow.custom.custom_component.component import Component from langflow.custom.utils import build_custom_component_template diff --git a/src/backend/tests/unit/components/tools/test_yfinance_tool.py b/src/backend/tests/unit/components/tools/test_yfinance_tool.py index 75630a42daec..6dd7258bfc02 100644 --- a/src/backend/tests/unit/components/tools/test_yfinance_tool.py +++ b/src/backend/tests/unit/components/tools/test_yfinance_tool.py @@ -1,5 +1,4 @@ import pytest - from langflow.components.tools.YfinanceTool import YfinanceToolComponent from langflow.custom.custom_component.component import Component from langflow.custom.utils import build_custom_component_template diff --git a/src/backend/tests/unit/custom/component/test_component_to_tool.py b/src/backend/tests/unit/custom/component/test_component_to_tool.py index df8e7cbf75aa..c6400a38ad17 100644 --- a/src/backend/tests/unit/custom/component/test_component_to_tool.py +++ b/src/backend/tests/unit/custom/component/test_component_to_tool.py @@ -1,7 +1,6 @@ from collections.abc import Callable import pytest - from langflow.components.inputs.ChatInput import ChatInput diff --git a/src/backend/tests/unit/custom/custom_component/test_component.py b/src/backend/tests/unit/custom/custom_component/test_component.py index be67c774df77..485b606930cf 100644 --- a/src/backend/tests/unit/custom/custom_component/test_component.py +++ b/src/backend/tests/unit/custom/custom_component/test_component.py @@ -1,5 +1,4 @@ import pytest - from langflow.components.agents.CrewAIAgent import CrewAIAgentComponent from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent from langflow.components.helpers.SequentialTask import SequentialTaskComponent diff --git a/src/backend/tests/unit/events/test_event_manager.py b/src/backend/tests/unit/events/test_event_manager.py index 98d19a15eafb..34b20455fb9f 100644 --- a/src/backend/tests/unit/events/test_event_manager.py +++ b/src/backend/tests/unit/events/test_event_manager.py @@ -4,7 +4,6 @@ import uuid import pytest - from langflow.events.event_manager import EventManager from langflow.schema.log import LoggableType diff --git a/src/backend/tests/unit/exceptions/test_api.py b/src/backend/tests/unit/exceptions/test_api.py index 542986f59272..38cf5f57f36a 100644 --- a/src/backend/tests/unit/exceptions/test_api.py +++ b/src/backend/tests/unit/exceptions/test_api.py @@ -1,4 +1,5 @@ -from unittest.mock import patch, Mock +from unittest.mock import Mock, patch + from langflow.services.database.models.flow.model import Flow diff --git a/src/backend/tests/unit/graph/edge/test_edge_base.py b/src/backend/tests/unit/graph/edge/test_edge_base.py index 62676c4fbc94..992b2f241cd6 100644 --- a/src/backend/tests/unit/graph/edge/test_edge_base.py +++ b/src/backend/tests/unit/graph/edge/test_edge_base.py @@ -1,5 +1,4 @@ import pytest - from langflow.components.inputs.ChatInput import ChatInput from langflow.components.models.OpenAIModel import OpenAIModelComponent from langflow.components.outputs.ChatOutput import ChatOutput diff --git a/src/backend/tests/unit/graph/graph/state/test_state_model.py b/src/backend/tests/unit/graph/graph/state/test_state_model.py index 1c1e800cab16..7a563cb50902 100644 --- a/src/backend/tests/unit/graph/graph/state/test_state_model.py +++ b/src/backend/tests/unit/graph/graph/state/test_state_model.py @@ -1,12 +1,11 @@ import pytest -from pydantic import Field - from langflow.components.inputs import ChatInput from langflow.components.outputs.ChatOutput import ChatOutput from langflow.graph.graph.base import Graph from langflow.graph.graph.constants import Finish from langflow.graph.state.model import create_state_model from langflow.template.field.base import UNDEFINED +from pydantic import Field @pytest.fixture diff --git a/src/backend/tests/unit/graph/graph/test_base.py b/src/backend/tests/unit/graph/graph/test_base.py index 8a7e328149e8..98f6701b9693 100644 --- a/src/backend/tests/unit/graph/graph/test_base.py +++ b/src/backend/tests/unit/graph/graph/test_base.py @@ -2,8 +2,6 @@ from collections import deque import pytest -from pytest import LogCaptureFixture - from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent from langflow.components.inputs.ChatInput import ChatInput from langflow.components.outputs.ChatOutput import ChatOutput @@ -11,6 +9,7 @@ from langflow.components.tools.YfinanceTool import YfinanceToolComponent from langflow.graph.graph.base import Graph from langflow.graph.graph.constants import Finish +from pytest import LogCaptureFixture @pytest.fixture diff --git a/src/backend/tests/unit/graph/graph/test_callback_graph.py b/src/backend/tests/unit/graph/graph/test_callback_graph.py index d70c7cbe48af..046c70b2260b 100644 --- a/src/backend/tests/unit/graph/graph/test_callback_graph.py +++ b/src/backend/tests/unit/graph/graph/test_callback_graph.py @@ -1,7 +1,6 @@ import asyncio import pytest - from langflow.components.outputs.ChatOutput import ChatOutput from langflow.custom.custom_component.component import Component from langflow.events.event_manager import EventManager diff --git a/src/backend/tests/unit/graph/graph/test_cycles.py b/src/backend/tests/unit/graph/graph/test_cycles.py index 573ec6f0356f..06a97289c32d 100644 --- a/src/backend/tests/unit/graph/graph/test_cycles.py +++ b/src/backend/tests/unit/graph/graph/test_cycles.py @@ -1,7 +1,6 @@ import os import pytest - from langflow.components.inputs.ChatInput import ChatInput from langflow.components.models.OpenAIModel import OpenAIModelComponent from langflow.components.outputs.ChatOutput import ChatOutput @@ -208,5 +207,3 @@ def test_updated_graph_with_prompts(): # Extract the vertex IDs for analysis results_ids = [result.vertex.id for result in results if hasattr(result, "vertex")] assert "chat_output_1" in results_ids, f"Expected outputs not in results: {results_ids}" - - print(f"Execution completed with results: {results_ids}") diff --git a/src/backend/tests/unit/graph/graph/test_graph_state_model.py b/src/backend/tests/unit/graph/graph/test_graph_state_model.py index afdb23e15901..7ba09f684bd4 100644 --- a/src/backend/tests/unit/graph/graph/test_graph_state_model.py +++ b/src/backend/tests/unit/graph/graph/test_graph_state_model.py @@ -1,6 +1,6 @@ -import pytest -from pydantic import BaseModel +from typing import TYPE_CHECKING +import pytest from langflow.components.helpers.Memory import MemoryComponent from langflow.components.inputs.ChatInput import ChatInput from langflow.components.models.OpenAIModel import OpenAIModelComponent @@ -10,6 +10,9 @@ from langflow.graph.graph.constants import Finish from langflow.graph.graph.state_model import create_state_model_from_graph +if TYPE_CHECKING: + from pydantic import BaseModel + @pytest.fixture def client(): diff --git a/src/backend/tests/unit/graph/graph/test_runnable_vertices_manager.py b/src/backend/tests/unit/graph/graph/test_runnable_vertices_manager.py index f42b2e9493c0..f3cb55f9921c 100644 --- a/src/backend/tests/unit/graph/graph/test_runnable_vertices_manager.py +++ b/src/backend/tests/unit/graph/graph/test_runnable_vertices_manager.py @@ -1,10 +1,12 @@ import pickle -from collections import defaultdict +from typing import TYPE_CHECKING import pytest - from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager +if TYPE_CHECKING: + from collections import defaultdict + @pytest.fixture def client(): @@ -28,7 +30,7 @@ def data(): def test_to_dict(data): result = RunnableVerticesManager.from_dict(data).to_dict() - assert all(key in result.keys() for key in data.keys()) + assert all(key in result for key in data) def test_from_dict(data): @@ -163,8 +165,8 @@ def test_build_run_map(data): manager.build_run_map(predecessor_map, vertices_to_run) - assert all(v in manager.run_map.keys() for v in ["Z", "X", "Y"]) - assert "W" not in manager.run_map.keys() + assert all(v in manager.run_map for v in ["Z", "X", "Y"]) + assert "W" not in manager.run_map def test_update_vertex_run_state(data): diff --git a/src/backend/tests/unit/graph/graph/test_utils.py b/src/backend/tests/unit/graph/graph/test_utils.py index 54c4788029ff..ae1acca5d0e0 100644 --- a/src/backend/tests/unit/graph/graph/test_utils.py +++ b/src/backend/tests/unit/graph/graph/test_utils.py @@ -1,7 +1,6 @@ import copy import pytest - from langflow.graph.graph import utils diff --git a/src/backend/tests/unit/graph/test_graph.py b/src/backend/tests/unit/graph/test_graph.py index 0a96a748c57c..9bb4e253bddb 100644 --- a/src/backend/tests/unit/graph/test_graph.py +++ b/src/backend/tests/unit/graph/test_graph.py @@ -2,7 +2,6 @@ import json import pytest - from langflow.graph import Graph from langflow.graph.graph.utils import ( find_last_node, @@ -64,7 +63,7 @@ def sample_nodes(): def get_node_by_type(graph, node_type: type[Vertex]) -> Vertex | None: - """Get a node by type""" + """Get a node by type.""" return next((node for node in graph.vertices if isinstance(node, node_type)), None) @@ -136,10 +135,10 @@ def test_process_flow_one_group(one_grouped_chat_json_flow): node_data = group_node["data"]["node"] assert node_data.get("flow") is not None template_data = node_data["template"] - assert any("openai_api_key" in key for key in template_data.keys()) + assert any("openai_api_key" in key for key in template_data) # Get the openai_api_key dict openai_api_key = next( - (template_data[key] for key in template_data.keys() if "openai_api_key" in key), + (template_data[key] for key in template_data if "openai_api_key" in key), None, ) assert openai_api_key is not None diff --git a/src/backend/tests/unit/initial_setup/starter_projects/test_memory_chatbot.py b/src/backend/tests/unit/initial_setup/starter_projects/test_memory_chatbot.py index 5ccfa421bb50..59016abd6f85 100644 --- a/src/backend/tests/unit/initial_setup/starter_projects/test_memory_chatbot.py +++ b/src/backend/tests/unit/initial_setup/starter_projects/test_memory_chatbot.py @@ -1,7 +1,8 @@ +import operator from collections import deque +from typing import TYPE_CHECKING import pytest - from langflow.components.helpers.Memory import MemoryComponent from langflow.components.inputs.ChatInput import ChatInput from langflow.components.models.OpenAIModel import OpenAIModelComponent @@ -9,7 +10,9 @@ from langflow.components.prompts.Prompt import PromptComponent from langflow.graph import Graph from langflow.graph.graph.constants import Finish -from langflow.graph.graph.schema import GraphDump + +if TYPE_CHECKING: + from langflow.graph.graph.schema import GraphDump @pytest.fixture @@ -100,7 +103,7 @@ def test_memory_chatbot_dump_components_and_edges(memory_chatbot_graph: Graph): edges = data_dict["edges"] # sort the nodes by id - nodes = sorted(nodes, key=lambda x: x["id"]) + nodes = sorted(nodes, key=operator.itemgetter("id")) # Check each node assert nodes[0]["data"]["type"] == "ChatInput" diff --git a/src/backend/tests/unit/initial_setup/starter_projects/test_vector_store_rag.py b/src/backend/tests/unit/initial_setup/starter_projects/test_vector_store_rag.py index d2d948970838..1ad6b3d09be5 100644 --- a/src/backend/tests/unit/initial_setup/starter_projects/test_vector_store_rag.py +++ b/src/backend/tests/unit/initial_setup/starter_projects/test_vector_store_rag.py @@ -1,8 +1,8 @@ import copy +import operator from textwrap import dedent import pytest - from langflow.components.data.File import FileComponent from langflow.components.embeddings.OpenAIEmbeddings import OpenAIEmbeddingsComponent from langflow.components.helpers.ParseData import ParseDataComponent @@ -45,8 +45,7 @@ def ingestion_graph(): vector_store.set_on_output(name="base_retriever", value="mock_retriever", cache=True) vector_store.set_on_output(name="search_results", value=[Data(text="This is a test file.")], cache=True) - ingestion_graph = Graph(file_component, vector_store) - return ingestion_graph + return Graph(file_component, vector_store) @pytest.fixture @@ -94,8 +93,7 @@ def rag_graph(): chat_output = ChatOutput(_id="chatoutput-123") chat_output.set(input_value=openai_component.text_response) - graph = Graph(start=chat_input, end=chat_output) - return graph + return Graph(start=chat_input, end=chat_output) def test_vector_store_rag(ingestion_graph, rag_graph): @@ -116,7 +114,7 @@ def test_vector_store_rag(ingestion_graph, rag_graph): "rag-vector-store-123", "openai-embeddings-124", ] - for ids, graph, len_results in zip([ingestion_ids, rag_ids], [ingestion_graph, rag_graph], [5, 8]): + for ids, graph, len_results in [(ingestion_ids, ingestion_graph, 5), (rag_ids, rag_graph, 8)]: results = [] for result in graph.start(): results.append(result) @@ -139,7 +137,7 @@ def test_vector_store_rag_dump_components_and_edges(ingestion_graph, rag_graph): ingestion_edges = ingestion_data["edges"] # Sort nodes by id to check components - ingestion_nodes = sorted(ingestion_nodes, key=lambda x: x["id"]) + ingestion_nodes = sorted(ingestion_nodes, key=operator.itemgetter("id")) # Check components in the ingestion graph assert ingestion_nodes[0]["data"]["type"] == "File" @@ -177,7 +175,7 @@ def test_vector_store_rag_dump_components_and_edges(ingestion_graph, rag_graph): rag_edges = rag_data["edges"] # Sort nodes by id to check components - rag_nodes = sorted(rag_nodes, key=lambda x: x["id"]) + rag_nodes = sorted(rag_nodes, key=operator.itemgetter("id")) # Check components in the RAG graph assert rag_nodes[0]["data"]["type"] == "ChatInput" @@ -240,7 +238,7 @@ def test_vector_store_rag_add(ingestion_graph: Graph, rag_graph: Graph): combined_edges = combined_data["edges"] # Sort nodes by id to check components - combined_nodes = sorted(combined_nodes, key=lambda x: x["id"]) + combined_nodes = sorted(combined_nodes, key=operator.itemgetter("id")) # Expected components in the combined graph (both ingestion and RAG nodes) expected_nodes = sorted( @@ -257,10 +255,10 @@ def test_vector_store_rag_add(ingestion_graph: Graph, rag_graph: Graph): {"id": "prompt-123", "type": "Prompt"}, {"id": "rag-vector-store-123", "type": "AstraDB"}, ], - key=lambda x: x["id"], + key=operator.itemgetter("id"), ) - for expected_node, combined_node in zip(expected_nodes, combined_nodes): + for expected_node, combined_node in zip(expected_nodes, combined_nodes, strict=True): assert combined_node["data"]["type"] == expected_node["type"] assert combined_node["id"] == expected_node["id"] diff --git a/src/backend/tests/unit/inputs/test_inputs.py b/src/backend/tests/unit/inputs/test_inputs.py index f8b7c7796c31..5e734bea5c2b 100644 --- a/src/backend/tests/unit/inputs/test_inputs.py +++ b/src/backend/tests/unit/inputs/test_inputs.py @@ -1,6 +1,4 @@ import pytest -from pydantic import ValidationError - from langflow.inputs.inputs import ( BoolInput, CodeInput, @@ -24,6 +22,7 @@ ) from langflow.inputs.utils import instantiate_input from langflow.schema.message import Message +from pydantic import ValidationError @pytest.fixture diff --git a/src/backend/tests/unit/io/test_io_schema.py b/src/backend/tests/unit/io/test_io_schema.py index 97caead0ba6c..5e970d69929c 100644 --- a/src/backend/tests/unit/io/test_io_schema.py +++ b/src/backend/tests/unit/io/test_io_schema.py @@ -1,10 +1,11 @@ -from typing import Literal +from typing import TYPE_CHECKING, Literal import pytest -from pydantic.fields import FieldInfo - from langflow.components.inputs.ChatInput import ChatInput +if TYPE_CHECKING: + from pydantic.fields import FieldInfo + @pytest.fixture def client(): diff --git a/src/backend/tests/unit/io/test_table_schema.py b/src/backend/tests/unit/io/test_table_schema.py index cefa5f82f3c0..b8de676f19aa 100644 --- a/src/backend/tests/unit/io/test_table_schema.py +++ b/src/backend/tests/unit/io/test_table_schema.py @@ -1,7 +1,6 @@ # Generated by qodo Gen import pytest - from langflow.schema.table import Column, FormatterType diff --git a/src/backend/tests/unit/schema/test_schema_message.py b/src/backend/tests/unit/schema/test_schema_message.py index 846ca4d61b9f..508ae9c9de17 100644 --- a/src/backend/tests/unit/schema/test_schema_message.py +++ b/src/backend/tests/unit/schema/test_schema_message.py @@ -1,6 +1,5 @@ import pytest from langchain_core.prompts.chat import ChatPromptTemplate - from langflow.schema.message import Message diff --git a/src/backend/tests/unit/services/variable/test_service.py b/src/backend/tests/unit/services/variable/test_service.py index cb3a7824c3ce..6bf80ac5143c 100644 --- a/src/backend/tests/unit/services/variable/test_service.py +++ b/src/backend/tests/unit/services/variable/test_service.py @@ -3,12 +3,11 @@ from uuid import uuid4 import pytest -from sqlmodel import Session, SQLModel, create_engine - from langflow.services.database.models.variable.model import VariableUpdate from langflow.services.deps import get_settings_service from langflow.services.variable.constants import CREDENTIAL_TYPE, GENERIC_TYPE from langflow.services.variable.service import DatabaseVariableService +from sqlmodel import Session, SQLModel, create_engine @pytest.fixture diff --git a/src/backend/tests/unit/test_api_key.py b/src/backend/tests/unit/test_api_key.py index 5deaa34367eb..d99eb11fee96 100644 --- a/src/backend/tests/unit/test_api_key.py +++ b/src/backend/tests/unit/test_api_key.py @@ -1,6 +1,5 @@ import pytest from httpx import AsyncClient - from langflow.services.database.models.api_key import ApiKeyCreate @@ -29,7 +28,8 @@ async def test_create_api_key(client: AsyncClient, logged_in_headers): response = await client.post("api/v1/api_key/", json={"name": api_key_name}, headers=logged_in_headers) assert response.status_code == 200 data = response.json() - assert "name" in data and data["name"] == api_key_name + assert "name" in data + assert data["name"] == api_key_name assert "api_key" in data assert "**" not in data["api_key"] diff --git a/src/backend/tests/unit/test_chat_endpoint.py b/src/backend/tests/unit/test_chat_endpoint.py index 5140cc3f0a47..9d267f266818 100644 --- a/src/backend/tests/unit/test_chat_endpoint.py +++ b/src/backend/tests/unit/test_chat_endpoint.py @@ -1,10 +1,9 @@ import json from uuid import UUID -from orjson import orjson - from langflow.memory import get_messages from langflow.services.database.models.flow import FlowCreate, FlowUpdate +from orjson import orjson async def test_build_flow(client, json_memory_chatbot_no_llm, logged_in_headers): @@ -82,7 +81,8 @@ async def consume_and_assert_stream(r): elif count == 5: assert parsed["event"] == "end" else: - raise ValueError(f"Unexpected line: {line}") + msg = f"Unexpected line: {line}" + raise ValueError(msg) count += 1 @@ -92,5 +92,4 @@ async def _create_flow(client, json_memory_chatbot_no_llm, logged_in_headers): vector_store = FlowCreate(name="Flow", description="description", data=data, endpoint_name="f") response = await client.post("api/v1/flows/", json=vector_store.model_dump(), headers=logged_in_headers) response.raise_for_status() - flow_id = response.json()["id"] - return flow_id + return response.json()["id"] diff --git a/src/backend/tests/unit/test_cli.py b/src/backend/tests/unit/test_cli.py index a59b169057f2..10bdac0dde69 100644 --- a/src/backend/tests/unit/test_cli.py +++ b/src/backend/tests/unit/test_cli.py @@ -1,5 +1,4 @@ import pytest - from langflow.__main__ import app from langflow.services import deps diff --git a/src/backend/tests/unit/test_custom_component.py b/src/backend/tests/unit/test_custom_component.py index 63103a379c07..6e9a13e395b1 100644 --- a/src/backend/tests/unit/test_custom_component.py +++ b/src/backend/tests/unit/test_custom_component.py @@ -4,7 +4,6 @@ import pytest from langchain_core.documents import Document - from langflow.custom import Component, CustomComponent from langflow.custom.code_parser.code_parser import CodeParser, CodeSyntaxError from langflow.custom.custom_component.base_component import BaseComponent, ComponentCodeNullError @@ -18,7 +17,7 @@ def client(): @pytest.fixture def code_component_with_multiple_outputs(): - with open("src/backend/tests/data/component_multiple_outputs.py") as f: + with open("src/backend/tests/data/component_multiple_outputs.py", encoding="utf-8") as f: code = f.read() return Component(_code=code) @@ -44,25 +43,20 @@ def build(self, url: str, llm: BaseLanguageModel) -> Document: def test_code_parser_init(): - """ - Test the initialization of the CodeParser class. - """ + """Test the initialization of the CodeParser class.""" parser = CodeParser(code_default) assert parser.code == code_default def test_code_parser_get_tree(): - """ - Test the __get_tree method of the CodeParser class. - """ + """Test the __get_tree method of the CodeParser class.""" parser = CodeParser(code_default) tree = parser.get_tree() assert isinstance(tree, ast.AST) def test_code_parser_syntax_error(): - """ - Test the __get_tree method raises the + """Test the __get_tree method raises the CodeSyntaxError when given incorrect syntax. """ code_syntax_error = "zzz import os" @@ -73,26 +67,21 @@ def test_code_parser_syntax_error(): def test_component_init(): - """ - Test the initialization of the Component class. - """ + """Test the initialization of the Component class.""" component = BaseComponent(_code=code_default, _function_entrypoint_name="build") assert component._code == code_default assert component._function_entrypoint_name == "build" def test_component_get_code_tree(): - """ - Test the get_code_tree method of the Component class. - """ + """Test the get_code_tree method of the Component class.""" component = BaseComponent(_code=code_default, _function_entrypoint_name="build") tree = component.get_code_tree(component._code) assert "imports" in tree def test_component_code_null_error(): - """ - Test the get_function method raises the + """Test the get_function method raises the ComponentCodeNullError when the code is empty. """ component = BaseComponent(_code="", _function_entrypoint_name="") @@ -101,9 +90,7 @@ def test_component_code_null_error(): def test_custom_component_init(): - """ - Test the initialization of the CustomComponent class. - """ + """Test the initialization of the CustomComponent class.""" function_entrypoint_name = "build" custom_component = CustomComponent(_code=code_default, _function_entrypoint_name=function_entrypoint_name) @@ -112,26 +99,21 @@ def test_custom_component_init(): def test_custom_component_build_template_config(): - """ - Test the build_template_config property of the CustomComponent class. - """ + """Test the build_template_config property of the CustomComponent class.""" custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build") config = custom_component.build_template_config() assert isinstance(config, dict) def test_custom_component_get_function(): - """ - Test the get_function property of the CustomComponent class. - """ + """Test the get_function property of the CustomComponent class.""" custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build") my_function = custom_component.get_function() assert isinstance(my_function, types.FunctionType) def test_code_parser_parse_imports_import(): - """ - Test the parse_imports method of the CodeParser + """Test the parse_imports method of the CodeParser class with an import statement. """ parser = CodeParser(code_default) @@ -143,8 +125,7 @@ class with an import statement. def test_code_parser_parse_imports_importfrom(): - """ - Test the parse_imports method of the CodeParser + """Test the parse_imports method of the CodeParser class with an import from statement. """ parser = CodeParser("from os import path") @@ -156,9 +137,7 @@ class with an import from statement. def test_code_parser_parse_functions(): - """ - Test the parse_functions method of the CodeParser class. - """ + """Test the parse_functions method of the CodeParser class.""" parser = CodeParser("def test(): pass") tree = parser.get_tree() for node in ast.walk(tree): @@ -169,9 +148,7 @@ def test_code_parser_parse_functions(): def test_code_parser_parse_classes(): - """ - Test the parse_classes method of the CodeParser class. - """ + """Test the parse_classes method of the CodeParser class.""" parser = CodeParser("from langflow.custom import Component\n\nclass Test(Component): pass") tree = parser.get_tree() for node in ast.walk(tree): @@ -182,9 +159,7 @@ def test_code_parser_parse_classes(): def test_code_parser_parse_classes_raises(): - """ - Test the parse_classes method of the CodeParser class. - """ + """Test the parse_classes method of the CodeParser class.""" parser = CodeParser("class Test: pass") tree = parser.get_tree() with pytest.raises(TypeError): @@ -194,9 +169,7 @@ def test_code_parser_parse_classes_raises(): def test_code_parser_parse_global_vars(): - """ - Test the parse_global_vars method of the CodeParser class. - """ + """Test the parse_global_vars method of the CodeParser class.""" parser = CodeParser("x = 1") tree = parser.get_tree() for node in ast.walk(tree): @@ -207,8 +180,7 @@ def test_code_parser_parse_global_vars(): def test_component_get_function_valid(): - """ - Test the get_function method of the Component + """Test the get_function method of the Component class with valid code and function_entrypoint_name. """ component = BaseComponent(_code="def build(): pass", _function_entrypoint_name="build") @@ -217,8 +189,7 @@ class with valid code and function_entrypoint_name. def test_custom_component_get_function_entrypoint_args(): - """ - Test the get_function_entrypoint_args + """Test the get_function_entrypoint_args property of the CustomComponent class. """ custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build") @@ -230,28 +201,23 @@ def test_custom_component_get_function_entrypoint_args(): def test_custom_component_get_function_entrypoint_return_type(): - """ - Test the get_function_entrypoint_return_type + """Test the get_function_entrypoint_return_type property of the CustomComponent class. """ - custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build") return_type = custom_component.get_function_entrypoint_return_type assert return_type == [Document] def test_custom_component_get_main_class_name(): - """ - Test the get_main_class_name property of the CustomComponent class. - """ + """Test the get_main_class_name property of the CustomComponent class.""" custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build") class_name = custom_component.get_main_class_name assert class_name == "YourComponent" def test_custom_component_get_function_valid(): - """ - Test the get_function property of the CustomComponent + """Test the get_function property of the CustomComponent class with valid code and function_entrypoint_name. """ custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build") @@ -260,9 +226,7 @@ class with valid code and function_entrypoint_name. def test_code_parser_parse_arg_no_annotation(): - """ - Test the parse_arg method of the CodeParser class without an annotation. - """ + """Test the parse_arg method of the CodeParser class without an annotation.""" parser = CodeParser("") arg = ast.arg(arg="x", annotation=None) result = parser.parse_arg(arg, None) @@ -271,9 +235,7 @@ def test_code_parser_parse_arg_no_annotation(): def test_code_parser_parse_arg_with_annotation(): - """ - Test the parse_arg method of the CodeParser class with an annotation. - """ + """Test the parse_arg method of the CodeParser class with an annotation.""" parser = CodeParser("") arg = ast.arg(arg="x", annotation=ast.Name(id="int", ctx=ast.Load())) result = parser.parse_arg(arg, None) @@ -282,8 +244,7 @@ def test_code_parser_parse_arg_with_annotation(): def test_code_parser_parse_callable_details_no_args(): - """ - Test the parse_callable_details method of the + """Test the parse_callable_details method of the CodeParser class with a function with no arguments. """ parser = CodeParser("") @@ -300,9 +261,7 @@ def test_code_parser_parse_callable_details_no_args(): def test_code_parser_parse_assign(): - """ - Test the parse_assign method of the CodeParser class. - """ + """Test the parse_assign method of the CodeParser class.""" parser = CodeParser("") stmt = ast.Assign(targets=[ast.Name(id="x", ctx=ast.Store())], value=ast.Num(n=1)) result = parser.parse_assign(stmt) @@ -311,9 +270,7 @@ def test_code_parser_parse_assign(): def test_code_parser_parse_ann_assign(): - """ - Test the parse_ann_assign method of the CodeParser class. - """ + """Test the parse_ann_assign method of the CodeParser class.""" parser = CodeParser("") stmt = ast.AnnAssign( target=ast.Name(id="x", ctx=ast.Store()), @@ -328,8 +285,7 @@ def test_code_parser_parse_ann_assign(): def test_code_parser_parse_function_def_not_init(): - """ - Test the parse_function_def method of the + """Test the parse_function_def method of the CodeParser class with a function that is not __init__. """ parser = CodeParser("") @@ -346,8 +302,7 @@ def test_code_parser_parse_function_def_not_init(): def test_code_parser_parse_function_def_init(): - """ - Test the parse_function_def method of the + """Test the parse_function_def method of the CodeParser class with an __init__ function. """ parser = CodeParser("") @@ -364,8 +319,7 @@ def test_code_parser_parse_function_def_init(): def test_component_get_code_tree_syntax_error(): - """ - Test the get_code_tree method of the Component class + """Test the get_code_tree method of the Component class raises the CodeSyntaxError when given incorrect syntax. """ component = BaseComponent(_code="import os as", _function_entrypoint_name="build") @@ -374,8 +328,7 @@ def test_component_get_code_tree_syntax_error(): def test_custom_component_class_template_validation_no_code(): - """ - Test the _class_template_validation method of the CustomComponent class + """Test the _class_template_validation method of the CustomComponent class raises the HTTPException when the code is None. """ custom_component = CustomComponent(_code=None, _function_entrypoint_name="build") @@ -384,8 +337,7 @@ def test_custom_component_class_template_validation_no_code(): def test_custom_component_get_code_tree_syntax_error(): - """ - Test the get_code_tree method of the CustomComponent class + """Test the get_code_tree method of the CustomComponent class raises the CodeSyntaxError when given incorrect syntax. """ custom_component = CustomComponent(_code="import os as", _function_entrypoint_name="build") @@ -394,8 +346,7 @@ def test_custom_component_get_code_tree_syntax_error(): def test_custom_component_get_function_entrypoint_args_no_args(): - """ - Test the get_function_entrypoint_args property of + """Test the get_function_entrypoint_args property of the CustomComponent class with a build method with no arguments. """ my_code = """ @@ -410,8 +361,7 @@ def build(): def test_custom_component_get_function_entrypoint_return_type_no_return_type(): - """ - Test the get_function_entrypoint_return_type property of the + """Test the get_function_entrypoint_return_type property of the CustomComponent class with a build method with no return type. """ my_code = """ @@ -426,8 +376,7 @@ def build(): def test_custom_component_get_main_class_name_no_main_class(): - """ - Test the get_main_class_name property of the + """Test the get_main_class_name property of the CustomComponent class when there is no main class. """ my_code = """ @@ -440,8 +389,7 @@ def build(): def test_custom_component_build_not_implemented(): - """ - Test the build method of the CustomComponent + """Test the build method of the CustomComponent class raises the NotImplementedError. """ custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build") @@ -458,7 +406,7 @@ def test_build_config_no_code(): @pytest.fixture def component(): - yield CustomComponent( + return CustomComponent( field_config={ "fields": { "llm": {"type": "str"}, diff --git a/src/backend/tests/unit/test_custom_component_with_client.py b/src/backend/tests/unit/test_custom_component_with_client.py index 65b1b873c93f..a88530a0aab4 100644 --- a/src/backend/tests/unit/test_custom_component_with_client.py +++ b/src/backend/tests/unit/test_custom_component_with_client.py @@ -1,5 +1,4 @@ import pytest - from langflow.custom import Component from langflow.custom.custom_component.custom_component import CustomComponent from langflow.custom.utils import build_custom_component_template @@ -9,7 +8,7 @@ @pytest.fixture def code_component_with_multiple_outputs(): - with open("src/backend/tests/data/component_multiple_outputs.py") as f: + with open("src/backend/tests/data/component_multiple_outputs.py", encoding="utf-8") as f: code = f.read() return Component(_code=code) diff --git a/src/backend/tests/unit/test_data_class.py b/src/backend/tests/unit/test_data_class.py index 6e7374f8b938..72b23139dcad 100644 --- a/src/backend/tests/unit/test_data_class.py +++ b/src/backend/tests/unit/test_data_class.py @@ -1,6 +1,5 @@ import pytest from langchain_core.documents import Document - from langflow.schema import Data diff --git a/src/backend/tests/unit/test_data_components.py b/src/backend/tests/unit/test_data_components.py index 9906fe63002b..d5c82ec78052 100644 --- a/src/backend/tests/unit/test_data_components.py +++ b/src/backend/tests/unit/test_data_components.py @@ -1,13 +1,12 @@ import os import tempfile from pathlib import Path -from unittest.mock import Mock, patch, ANY +from unittest.mock import ANY, Mock, patch import httpx import pytest import respx from httpx import Response - from langflow.components import data @@ -170,10 +169,10 @@ def test_directory_without_mocks(): directory_component = data.DirectoryComponent() with tempfile.TemporaryDirectory() as temp_dir: - with open(temp_dir + "/test.txt", "w") as f: + with open(temp_dir + "/test.txt", "w", encoding="utf-8") as f: f.write("test") # also add a json file - with open(temp_dir + "/test.json", "w") as f: + with open(temp_dir + "/test.json", "w", encoding="utf-8") as f: f.write('{"test": "test"}') directory_component.set_attributes({"path": str(temp_dir), "use_multithreading": False}) @@ -181,7 +180,7 @@ def test_directory_without_mocks(): assert len(results) == 2 values = ["test", '{"test":"test"}'] assert all(result.text in values for result in results), [ - (len(result.text), len(val)) for result, val in zip(results, values) + (len(result.text), len(val)) for result, val in zip(results, values, strict=True) ] # in ../docs/docs/components there are many mdx files diff --git a/src/backend/tests/unit/test_database.py b/src/backend/tests/unit/test_database.py index 0f99d850b432..36643c6d2695 100644 --- a/src/backend/tests/unit/test_database.py +++ b/src/backend/tests/unit/test_database.py @@ -5,8 +5,6 @@ import orjson import pytest from fastapi.testclient import TestClient -from sqlmodel import Session - from langflow.api.v1.schemas import FlowListCreate, ResultDataResponse from langflow.graph.utils import log_transaction, log_vertex_build from langflow.initial_setup.setup import load_flows_from_directory, load_starter_projects @@ -15,6 +13,7 @@ from langflow.services.database.models.folder.model import FolderCreate from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service +from sqlmodel import Session @pytest.fixture(scope="module") @@ -506,5 +505,5 @@ def test_sqlite_pragmas(): with db_service.with_session() as session: from sqlalchemy import text - assert "wal" == session.exec(text("PRAGMA journal_mode;")).scalar() - assert 1 == session.exec(text("PRAGMA synchronous;")).scalar() + assert session.exec(text("PRAGMA journal_mode;")).scalar() == "wal" + assert session.exec(text("PRAGMA synchronous;")).scalar() == 1 diff --git a/src/backend/tests/unit/test_endpoints.py b/src/backend/tests/unit/test_endpoints.py index f41119e59ef5..5236a823c4c3 100644 --- a/src/backend/tests/unit/test_endpoints.py +++ b/src/backend/tests/unit/test_endpoints.py @@ -4,7 +4,6 @@ import pytest from fastapi import status from httpx import AsyncClient - from langflow.custom.directory_reader.directory_reader import DirectoryReader from langflow.services.deps import get_settings_service @@ -377,7 +376,7 @@ async def test_invalid_prompt(client: AsyncClient): @pytest.mark.parametrize( - "prompt,expected_input_variables", + ("prompt", "expected_input_variables"), [ ("{color} is my favorite color.", ["color"]), ("The weather is {weather} today.", ["weather"]), @@ -442,13 +441,13 @@ async def test_successful_run_no_payload(client, simple_api_test, created_api_ke assert isinstance(outputs_dict.get("outputs"), list) assert len(outputs_dict.get("outputs")) == 1 ids = [output.get("component_id") for output in outputs_dict.get("outputs")] - assert all(["ChatOutput" in _id for _id in ids]) + assert all("ChatOutput" in _id for _id in ids) display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")] - assert all([name in display_names for name in ["Chat Output"]]) + assert all(name in display_names for name in ["Chat Output"]) output_results_has_results = all("results" in output.get("results") for output in outputs_dict.get("outputs")) inner_results = [output.get("results") for output in outputs_dict.get("outputs")] - assert all([result is not None for result in inner_results]), (outputs_dict, output_results_has_results) + assert all(result is not None for result in inner_results), (outputs_dict, output_results_has_results) async def test_successful_run_with_output_type_text(client, simple_api_test, created_api_key): @@ -473,12 +472,12 @@ async def test_successful_run_with_output_type_text(client, simple_api_test, cre assert isinstance(outputs_dict.get("outputs"), list) assert len(outputs_dict.get("outputs")) == 1 ids = [output.get("component_id") for output in outputs_dict.get("outputs")] - assert all(["ChatOutput" in _id for _id in ids]), ids + assert all("ChatOutput" in _id for _id in ids), ids display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")] - assert all([name in display_names for name in ["Chat Output"]]), display_names + assert all(name in display_names for name in ["Chat Output"]), display_names inner_results = [output.get("results") for output in outputs_dict.get("outputs")] expected_keys = ["message"] - assert all([key in result for result in inner_results for key in expected_keys]), outputs_dict + assert all(key in result for result in inner_results for key in expected_keys), outputs_dict async def test_successful_run_with_output_type_any(client, simple_api_test, created_api_key): @@ -504,12 +503,12 @@ async def test_successful_run_with_output_type_any(client, simple_api_test, crea assert isinstance(outputs_dict.get("outputs"), list) assert len(outputs_dict.get("outputs")) == 1 ids = [output.get("component_id") for output in outputs_dict.get("outputs")] - assert all(["ChatOutput" in _id or "TextOutput" in _id for _id in ids]), ids + assert all("ChatOutput" in _id or "TextOutput" in _id for _id in ids), ids display_names = [output.get("component_display_name") for output in outputs_dict.get("outputs")] - assert all([name in display_names for name in ["Chat Output"]]), display_names + assert all(name in display_names for name in ["Chat Output"]), display_names inner_results = [output.get("results") for output in outputs_dict.get("outputs")] expected_keys = ["message"] - assert all([key in result for result in inner_results for key in expected_keys]), outputs_dict + assert all(key in result for result in inner_results for key in expected_keys), outputs_dict async def test_successful_run_with_output_type_debug(client, simple_api_test, created_api_key): @@ -566,7 +565,7 @@ async def test_successful_run_with_input_type_text(client, simple_api_test, crea # Now we check if the input_value is correct # We get text key twice because the output is now a Message assert all( - [output.get("results").get("text").get("text") == "value1" for output in text_input_outputs] + output.get("results").get("text").get("text") == "value1" for output in text_input_outputs ), text_input_outputs @@ -599,7 +598,7 @@ async def test_successful_run_with_input_type_chat(client: AsyncClient, simple_a assert len(chat_input_outputs) == 1 # Now we check if the input_value is correct assert all( - [output.get("results").get("message").get("text") == "value1" for output in chat_input_outputs] + output.get("results").get("message").get("text") == "value1" for output in chat_input_outputs ), chat_input_outputs @@ -653,7 +652,7 @@ async def test_successful_run_with_input_type_any(client, simple_api_test, creat result_dict.get("message", result_dict.get("text")) for result_dict in all_result_dicts ] assert all( - [message_or_text_dict.get("text") == "value1" for message_or_text_dict in all_message_or_text_dicts] + message_or_text_dict.get("text") == "value1" for message_or_text_dict in all_message_or_text_dicts ), any_input_outputs diff --git a/src/backend/tests/unit/test_experimental_components.py b/src/backend/tests/unit/test_experimental_components.py index 50a363f4a286..2c9a861d013b 100644 --- a/src/backend/tests/unit/test_experimental_components.py +++ b/src/backend/tests/unit/test_experimental_components.py @@ -1,5 +1,5 @@ -from langflow.components import prototypes import pytest +from langflow.components import prototypes @pytest.fixture diff --git a/src/backend/tests/unit/test_files.py b/src/backend/tests/unit/test_files.py index 33a5f5b707c1..13c306668427 100644 --- a/src/backend/tests/unit/test_files.py +++ b/src/backend/tests/unit/test_files.py @@ -9,10 +9,9 @@ import pytest from asgi_lifespan import LifespanManager from httpx import ASGITransport, AsyncClient -from sqlmodel import Session - from langflow.services.deps import get_storage_service from langflow.services.storage.service import StorageService +from sqlmodel import Session @pytest.fixture @@ -27,7 +26,7 @@ def mock_storage_service(): return service -@pytest.fixture(name="files_client", scope="function") +@pytest.fixture(name="files_client") async def files_client_fixture(session: Session, monkeypatch, request, load_flows_dir, mock_storage_service): # Set the database url to a test database if "noclient" in request.keywords: diff --git a/src/backend/tests/unit/test_frontend_nodes.py b/src/backend/tests/unit/test_frontend_nodes.py index fe9d71cc61b7..9cf19dd90090 100644 --- a/src/backend/tests/unit/test_frontend_nodes.py +++ b/src/backend/tests/unit/test_frontend_nodes.py @@ -1,5 +1,4 @@ import pytest - from langflow.template.field.base import Input from langflow.template.frontend_node.base import FrontendNode from langflow.template.template.base import Template diff --git a/src/backend/tests/unit/test_helper_components.py b/src/backend/tests/unit/test_helper_components.py index 9abe1c5588b1..673ea931e49e 100644 --- a/src/backend/tests/unit/test_helper_components.py +++ b/src/backend/tests/unit/test_helper_components.py @@ -1,8 +1,8 @@ +import pytest from langflow.components import helpers from langflow.custom.utils import build_custom_component_template from langflow.schema import Data from langflow.schema.message import Message -import pytest @pytest.fixture @@ -40,7 +40,7 @@ def client(): def test_uuid_generator_component(): # Arrange uuid_generator_component = helpers.IDGeneratorComponent() - uuid_generator_component._code = open(helpers.IDGenerator.__file__).read() + uuid_generator_component._code = open(helpers.IDGenerator.__file__, encoding="utf-8").read() frontend_node, _ = build_custom_component_template(uuid_generator_component) diff --git a/src/backend/tests/unit/test_initial_setup.py b/src/backend/tests/unit/test_initial_setup.py index 9110376d72e2..37a710814fbe 100644 --- a/src/backend/tests/unit/test_initial_setup.py +++ b/src/backend/tests/unit/test_initial_setup.py @@ -2,8 +2,6 @@ from pathlib import Path import pytest -from sqlmodel import select - from langflow.custom.directory_reader.utils import build_custom_component_list_from_path from langflow.initial_setup.setup import ( STARTER_FOLDER_NAME, @@ -14,6 +12,7 @@ from langflow.interface.types import aget_all_types_dict from langflow.services.database.models.folder.model import Folder from langflow.services.deps import session_scope +from sqlmodel import select def test_load_starter_projects(): @@ -100,10 +99,11 @@ async def test_create_or_update_starter_projects(): def find_componeny_by_name(components, name): - for category, children in components.items(): + for children in components.values(): if name in children: return children[name] - raise ValueError(f"Component {name} not found in components") + msg = f"Component {name} not found in components" + raise ValueError(msg) def set_value(component, input_name, value): diff --git a/src/backend/tests/unit/test_kubernetes_secrets.py b/src/backend/tests/unit/test_kubernetes_secrets.py index 1e48272cf34c..7fcc52ed9bb5 100644 --- a/src/backend/tests/unit/test_kubernetes_secrets.py +++ b/src/backend/tests/unit/test_kubernetes_secrets.py @@ -1,9 +1,9 @@ -import pytest -from unittest.mock import MagicMock -from kubernetes.client import V1ObjectMeta, V1Secret from base64 import b64encode +from unittest.mock import MagicMock from uuid import UUID +import pytest +from kubernetes.client import V1ObjectMeta, V1Secret from langflow.services.variable.kubernetes_secrets import KubernetesSecretManager, encode_user_id diff --git a/src/backend/tests/unit/test_loading.py b/src/backend/tests/unit/test_loading.py index eaff7d17faff..467b7cb8f3f0 100644 --- a/src/backend/tests/unit/test_loading.py +++ b/src/backend/tests/unit/test_loading.py @@ -1,5 +1,4 @@ import pytest - from langflow.graph import Graph from langflow.initial_setup.setup import load_starter_projects from langflow.load import load_flow_from_json @@ -27,9 +26,8 @@ def client(): def test_load_flow_from_json_object(): - """Test loading a flow from a json file and applying tweaks""" - _, projects = zip(*load_starter_projects()) - project = projects[0] + """Test loading a flow from a json file and applying tweaks.""" + project = load_starter_projects()[0][1] loaded = load_flow_from_json(project) assert loaded is not None assert isinstance(loaded, Graph) diff --git a/src/backend/tests/unit/test_logger.py b/src/backend/tests/unit/test_logger.py index 91fa66a75fc9..545f6a922dd2 100644 --- a/src/backend/tests/unit/test_logger.py +++ b/src/backend/tests/unit/test_logger.py @@ -1,7 +1,8 @@ -import pytest -import os import json +import os from unittest.mock import patch + +import pytest from langflow.logging.logger import SizedLogBuffer @@ -32,8 +33,8 @@ def test_write(sized_log_buffer): sized_log_buffer.max = 1 # Set max size to 1 for testing sized_log_buffer.write(message) assert len(sized_log_buffer.buffer) == 1 - assert 1625097600124 == sized_log_buffer.buffer[0][0] - assert "Test log" == sized_log_buffer.buffer[0][1] + assert sized_log_buffer.buffer[0][0] == 1625097600124 + assert sized_log_buffer.buffer[0][1] == "Test log" def test_write_overflow(sized_log_buffer): @@ -43,8 +44,8 @@ def test_write_overflow(sized_log_buffer): sized_log_buffer.write(message) assert len(sized_log_buffer.buffer) == 2 - assert 1625097601000 == sized_log_buffer.buffer[0][0] - assert 1625097602000 == sized_log_buffer.buffer[1][0] + assert sized_log_buffer.buffer[0][0] == 1625097601000 + assert sized_log_buffer.buffer[1][0] == 1625097602000 def test_len(sized_log_buffer): diff --git a/src/backend/tests/unit/test_login.py b/src/backend/tests/unit/test_login.py index b8e945ae1090..16864f3ca487 100644 --- a/src/backend/tests/unit/test_login.py +++ b/src/backend/tests/unit/test_login.py @@ -1,9 +1,8 @@ import pytest -from sqlalchemy.exc import IntegrityError - from langflow.services.auth.utils import get_password_hash from langflow.services.database.models.user import User from langflow.services.deps import session_scope +from sqlalchemy.exc import IntegrityError @pytest.fixture diff --git a/src/backend/tests/unit/test_messages.py b/src/backend/tests/unit/test_messages.py index 5ae53bb3440b..d062cbd21f23 100644 --- a/src/backend/tests/unit/test_messages.py +++ b/src/backend/tests/unit/test_messages.py @@ -1,5 +1,4 @@ import pytest - from langflow.memory import add_messages, add_messagetables, delete_messages, get_messages, store_message from langflow.schema.message import Message @@ -10,17 +9,16 @@ from langflow.services.tracing.utils import convert_to_langchain_type -@pytest.fixture() +@pytest.fixture def created_message(): with session_scope() as session: message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id") messagetable = MessageTable.model_validate(message, from_attributes=True) messagetables = add_messagetables([messagetable], session) - message_read = MessageRead.model_validate(messagetables[0], from_attributes=True) - return message_read + return MessageRead.model_validate(messagetables[0], from_attributes=True) -@pytest.fixture() +@pytest.fixture def created_messages(session): with session_scope() as session: messages = [ @@ -30,10 +28,7 @@ def created_messages(session): ] messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages] messagetables = add_messagetables(messagetables, session) - messages_read = [ - MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables - ] - return messages_read + return [MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables] def test_get_messages(): @@ -82,10 +77,10 @@ def test_convert_to_langchain(method_name): def convert(value): if method_name == "message": return value.to_lc_message() - elif method_name == "convert_to_langchain_type": + if method_name == "convert_to_langchain_type": return convert_to_langchain_type(value) - else: - raise ValueError(f"Invalid method: {method_name}") + msg = f"Invalid method: {method_name}" + raise ValueError(msg) lc_message = convert(Message(text="Test message 1", sender="User", sender_name="User", session_id="session_id2")) assert lc_message.content == "Test message 1" diff --git a/src/backend/tests/unit/test_messages_endpoints.py b/src/backend/tests/unit/test_messages_endpoints.py index 78c15970c451..0202a7ab5ace 100644 --- a/src/backend/tests/unit/test_messages_endpoints.py +++ b/src/backend/tests/unit/test_messages_endpoints.py @@ -2,7 +2,6 @@ import pytest from httpx import AsyncClient - from langflow.memory import add_messagetables # Assuming you have these imports available @@ -11,17 +10,16 @@ from langflow.services.deps import session_scope -@pytest.fixture() +@pytest.fixture async def created_message(): with session_scope() as session: message = MessageCreate(text="Test message", sender="User", sender_name="User", session_id="session_id") messagetable = MessageTable.model_validate(message, from_attributes=True) messagetables = add_messagetables([messagetable], session) - message_read = MessageRead.model_validate(messagetables[0], from_attributes=True) - return message_read + return MessageRead.model_validate(messagetables[0], from_attributes=True) -@pytest.fixture() +@pytest.fixture def created_messages(session): with session_scope() as session: messages = [ @@ -30,9 +28,7 @@ def created_messages(session): MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"), ] messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages] - message_list = add_messagetables(messagetables, session) - - return message_list + return add_messagetables(messagetables, session) @pytest.mark.api_key_required diff --git a/src/backend/tests/unit/test_process.py b/src/backend/tests/unit/test_process.py index 6b62fa2e9506..c9f3f101f4b2 100644 --- a/src/backend/tests/unit/test_process.py +++ b/src/backend/tests/unit/test_process.py @@ -1,5 +1,4 @@ import pytest - from langflow.processing.process import process_tweaks from langflow.services.deps import get_session_service diff --git a/src/backend/tests/unit/test_schema.py b/src/backend/tests/unit/test_schema.py index 68128ef28e34..9d94aaccb524 100644 --- a/src/backend/tests/unit/test_schema.py +++ b/src/backend/tests/unit/test_schema.py @@ -1,14 +1,13 @@ +from collections.abc import Sequence as SequenceABC from types import NoneType from typing import Union -from langflow.schema.data import Data import pytest -from pydantic import ValidationError - +from langflow.schema.data import Data from langflow.template import Input, Output from langflow.template.field.base import UNDEFINED from langflow.type_extraction.type_extraction import post_process_type -from collections.abc import Sequence as SequenceABC +from pydantic import ValidationError @pytest.fixture(name="client", autouse=True) @@ -73,8 +72,8 @@ def test_post_process_type_function(self): assert set(post_process_type(Union[None, list[None]])) == {None, NoneType} # Handling complex nested structures - assert set(post_process_type(Union[SequenceABC[Union[int, str]], list[float]])) == {int, str, float} - assert set(post_process_type(Union[Union[Union[int, list[str]], list[float]], str])) == {int, str, float} + assert set(post_process_type(Union[SequenceABC[int | str], list[float]])) == {int, str, float} + assert set(post_process_type(Union[int | list[str] | list[float], str])) == {int, str, float} # Non-generic types should return as is assert set(post_process_type(dict)) == {dict} @@ -92,12 +91,12 @@ def test_post_process_type_function(self): assert set(post_process_type(Data | Union[float, None])) == {Data, float, type(None)} # Multiple Data types combined - assert set(post_process_type(Union[Data, Union[str, float]])) == {Data, str, float} + assert set(post_process_type(Union[Data, str | float])) == {Data, str, float} assert set(post_process_type(Union[Data | float | str, int])) == {Data, int, float, str} # Testing with nested unions and lists - assert set(post_process_type(Union[list[Data], list[Union[int, str]]])) == {Data, int, str} - assert set(post_process_type(Data | list[Union[float, str]])) == {Data, float, str} + assert set(post_process_type(Union[list[Data], list[int | str]])) == {Data, int, str} + assert set(post_process_type(Data | list[float | str])) == {Data, float, str} def test_input_to_dict(self): input_obj = Input(field_type="str") diff --git a/src/backend/tests/unit/test_telemetry.py b/src/backend/tests/unit/test_telemetry.py index 6406770d67e3..618493c64664 100644 --- a/src/backend/tests/unit/test_telemetry.py +++ b/src/backend/tests/unit/test_telemetry.py @@ -1,8 +1,8 @@ -import pytest import threading -from langflow.services.telemetry.opentelemetry import OpenTelemetry from concurrent.futures import ThreadPoolExecutor, as_completed +import pytest +from langflow.services.telemetry.opentelemetry import OpenTelemetry fixed_labels = {"flow_id": "this_flow_id", "service": "this", "user": "that"} @@ -72,9 +72,9 @@ def test_missing_labels(opentelemetry_instance): with pytest.raises(ValueError, match="Labels must be provided for the metric"): opentelemetry_instance.up_down_counter("num_files_uploaded", 1, None) with pytest.raises(ValueError, match="Labels must be provided for the metric"): - opentelemetry_instance.update_gauge(metric_name="num_files_uploaded", value=1.0, labels=dict()) + opentelemetry_instance.update_gauge(metric_name="num_files_uploaded", value=1.0, labels={}) with pytest.raises(ValueError, match="Labels must be provided for the metric"): - opentelemetry_instance.observe_histogram("num_files_uploaded", 1, dict()) + opentelemetry_instance.observe_histogram("num_files_uploaded", 1, {}) def test_multithreaded_singleton(): diff --git a/src/backend/tests/unit/test_template.py b/src/backend/tests/unit/test_template.py index b9b38f2fc540..085c4afade78 100644 --- a/src/backend/tests/unit/test_template.py +++ b/src/backend/tests/unit/test_template.py @@ -12,13 +12,13 @@ def client(): # Dummy classes for testing purposes class Parent(BaseModel): - """Parent Class""" + """Parent Class.""" parent_field: str class Child(Parent): - """Child Class""" + """Child Class.""" child_field: int @@ -90,7 +90,7 @@ def dummy_function(): return "default_value" # Add dummy_function to your_module - setattr(importlib.import_module(module_name), "dummy_function", dummy_function) + importlib.import_module(module_name).dummy_function = dummy_function default_value = get_default_factory(module_name, function_repr) diff --git a/src/backend/tests/unit/test_user.py b/src/backend/tests/unit/test_user.py index 974884541330..fff14a9e354f 100644 --- a/src/backend/tests/unit/test_user.py +++ b/src/backend/tests/unit/test_user.py @@ -2,13 +2,12 @@ import pytest from httpx import AsyncClient -from sqlmodel import select - from langflow.services.auth.utils import create_super_user, get_password_hash from langflow.services.database.models.user import UserUpdate from langflow.services.database.models.user.model import User from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service, get_settings_service +from sqlmodel import select @pytest.fixture @@ -86,7 +85,7 @@ async def test_user_waiting_for_approval(client): with session_getter(get_db_service()) as session: existing_user = session.exec(select(User).where(User.username == username)).first() if existing_user: - print(f"User {username} still exists after the test. This is expected.") + pass else: pytest.fail(f"User {username} does not exist after the test. This is unexpected.") diff --git a/src/backend/tests/unit/test_validate_code.py b/src/backend/tests/unit/test_validate_code.py index 6d4e5c4215e7..a8203d0e66d4 100644 --- a/src/backend/tests/unit/test_validate_code.py +++ b/src/backend/tests/unit/test_validate_code.py @@ -2,9 +2,8 @@ from unittest import mock import pytest -from requests.exceptions import MissingSchema - from langflow.utils.validate import create_function, execute_function, extract_function_name, validate_code +from requests.exceptions import MissingSchema @pytest.fixture @@ -104,6 +103,5 @@ def test_execute_function_missing_schema(): def my_function(x): return requests.get(x).text """ - with mock.patch("requests.get", side_effect=MissingSchema): - with pytest.raises(MissingSchema): - execute_function(code, "my_function", "invalid_url") + with mock.patch("requests.get", side_effect=MissingSchema), pytest.raises(MissingSchema): + execute_function(code, "my_function", "invalid_url") diff --git a/src/backend/tests/unit/test_version.py b/src/backend/tests/unit/test_version.py index 36ac7c5a480e..d068b52987cd 100644 --- a/src/backend/tests/unit/test_version.py +++ b/src/backend/tests/unit/test_version.py @@ -9,9 +9,9 @@ def test_version(): def test_compute_main(): - assert "1.0.10" == _compute_non_prerelease_version("1.0.10.post0") - assert "1.0.10" == _compute_non_prerelease_version("1.0.10.a1") - assert "1.0.10" == _compute_non_prerelease_version("1.0.10.b112") - assert "1.0.10" == _compute_non_prerelease_version("1.0.10.rc0") - assert "1.0.10" == _compute_non_prerelease_version("1.0.10.dev9") - assert "1.0.10" == _compute_non_prerelease_version("1.0.10") + assert _compute_non_prerelease_version("1.0.10.post0") == "1.0.10" + assert _compute_non_prerelease_version("1.0.10.a1") == "1.0.10" + assert _compute_non_prerelease_version("1.0.10.b112") == "1.0.10" + assert _compute_non_prerelease_version("1.0.10.rc0") == "1.0.10" + assert _compute_non_prerelease_version("1.0.10.dev9") == "1.0.10" + assert _compute_non_prerelease_version("1.0.10") == "1.0.10" diff --git a/src/backend/tests/unit/utils/test_connection_string_parser.py b/src/backend/tests/unit/utils/test_connection_string_parser.py index 1ab82279e5f2..6b28c6847180 100644 --- a/src/backend/tests/unit/utils/test_connection_string_parser.py +++ b/src/backend/tests/unit/utils/test_connection_string_parser.py @@ -8,7 +8,7 @@ def client(): @pytest.mark.parametrize( - "connection_string, expected", + ("connection_string", "expected"), [ ("protocol:user:password@host", "protocol:user:password@host"), ("protocol:user@host", "protocol:user@host"), diff --git a/src/backend/tests/unit/utils/test_format_directory_path.py b/src/backend/tests/unit/utils/test_format_directory_path.py index 329383108ef2..16ff40080b88 100644 --- a/src/backend/tests/unit/utils/test_format_directory_path.py +++ b/src/backend/tests/unit/utils/test_format_directory_path.py @@ -3,7 +3,7 @@ @pytest.mark.parametrize( - "input_path, expected", + ("input_path", "expected"), [ # Test case 1: Standard path with no newlines (no change expected) ("/home/user/documents/file.txt", "/home/user/documents/file.txt"), diff --git a/src/backend/tests/unit/utils/test_rewrite_file_path.py b/src/backend/tests/unit/utils/test_rewrite_file_path.py index d5dc57013405..bb30280e2b8a 100644 --- a/src/backend/tests/unit/utils/test_rewrite_file_path.py +++ b/src/backend/tests/unit/utils/test_rewrite_file_path.py @@ -1,9 +1,9 @@ -from langflow.base.data.utils import format_directory_path import pytest +from langflow.base.data.utils import format_directory_path @pytest.mark.parametrize( - "input_path, expected", + ("input_path", "expected"), [ # Test case 1: Standard path with no newlines ("/home/user/documents/file.txt", "/home/user/documents/file.txt"), diff --git a/src/backend/tests/unit/utils/test_truncate_long_strings.py b/src/backend/tests/unit/utils/test_truncate_long_strings.py index 50d495fb486b..aa7ce3f958fd 100644 --- a/src/backend/tests/unit/utils/test_truncate_long_strings.py +++ b/src/backend/tests/unit/utils/test_truncate_long_strings.py @@ -1,9 +1,11 @@ +import math + import pytest from langflow.utils.util_strings import truncate_long_strings @pytest.mark.parametrize( - "input_data, max_length, expected", + ("input_data", "max_length", "expected"), [ # Test case 1: String shorter than max_length ("short string", 20, "short string"), @@ -20,7 +22,7 @@ # Test case 7: Integer input (12345, 3, 12345), # Test case 8: Float input - (3.14159, 4, 3.14159), + (math.pi, 4, math.pi), # Test case 9: Boolean input (True, 2, True), # Test case 10: None input diff --git a/src/backend/tests/unit/utils/test_truncate_long_strings_on_objects.py b/src/backend/tests/unit/utils/test_truncate_long_strings_on_objects.py index 93fad7f83e4b..eafc3f10b140 100644 --- a/src/backend/tests/unit/utils/test_truncate_long_strings_on_objects.py +++ b/src/backend/tests/unit/utils/test_truncate_long_strings_on_objects.py @@ -1,10 +1,10 @@ -from langflow.utils.util_strings import truncate_long_strings -from langflow.utils.constants import MAX_TEXT_LENGTH import pytest +from langflow.utils.constants import MAX_TEXT_LENGTH +from langflow.utils.util_strings import truncate_long_strings @pytest.mark.parametrize( - "input_data, max_length, expected", + ("input_data", "max_length", "expected"), [ # Test case 1: Simple string truncation ({"key": "a" * 100}, 10, {"key": "a" * 10 + "..."}),