From 391d55ca6dab15409d7253a7253056a6280fa674 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Sat, 19 Oct 2024 22:33:14 +0200 Subject: [PATCH] Add ALL ruff rules for tests --- pyproject.toml | 37 +++++++++ scripts/ci/pypi_nightly_tag.py | 12 +-- scripts/ci/update_lf_base_dependency.py | 14 ++-- scripts/ci/update_pyproject_name.py | 11 ++- scripts/ci/update_pyproject_version.py | 14 ++-- scripts/ci/update_uv_dependency.py | 9 ++- .../custom/custom_component/component.py | 2 +- src/backend/langflow/version/__init__.py | 1 + src/backend/langflow/version/version.py | 7 +- src/backend/tests/conftest.py | 38 ++++++--- src/backend/tests/data/component.py | 2 +- .../tests/data/component_nested_call.py | 2 +- .../data/component_with_templatefield.py | 2 +- .../test_starter_projects.py | 1 + .../assistants/test_assistants_components.py | 1 + .../components/astra/test_astra_component.py | 3 +- .../helpers/test_parse_json_data.py | 1 + .../components/inputs/test_chat_input.py | 1 + .../components/inputs/test_text_input.py | 1 + .../output_parsers/test_output_parser.py | 1 + .../components/outputs/test_chat_output.py | 1 + .../components/outputs/test_text_output.py | 1 + .../components/prompts/test_prompt.py | 1 + .../integration/flows/test_basic_prompting.py | 9 ++- src/backend/tests/integration/utils.py | 25 +++--- src/backend/tests/locust/locustfile.py | 2 +- src/backend/tests/unit/api/test_api_utils.py | 5 +- .../tests/unit/api/v1/test_variable.py | 20 ++--- src/backend/tests/unit/base/load/test_load.py | 3 +- .../unit/base/tools/test_component_toolkit.py | 4 +- .../test_structured_output_component.py | 50 +++++++++--- .../models/test_ChatOllama_component.py | 2 +- .../prompts/test_prompt_component.py | 2 +- .../prototypes/test_create_data_component.py | 6 +- .../prototypes/test_update_data_component.py | 7 +- .../custom/custom_component/test_component.py | 2 +- .../tests/unit/events/test_event_manager.py | 4 +- src/backend/tests/unit/exceptions/test_api.py | 46 ++++++----- .../graph/graph/state/test_state_model.py | 39 +++++----- .../tests/unit/graph/graph/test_base.py | 17 ++-- .../tests/unit/graph/graph/test_cycles.py | 14 ++-- .../graph/graph/test_graph_state_model.py | 23 ++---- .../graph/test_runnable_vertices_manager.py | 2 +- .../tests/unit/graph/graph/test_utils.py | 4 +- src/backend/tests/unit/graph/test_graph.py | 4 +- .../helpers/test_base_model_from_schema.py | 6 +- .../starter_projects/test_memory_chatbot.py | 1 + .../starter_projects/test_vector_store_rag.py | 22 +++--- src/backend/tests/unit/inputs/test_inputs.py | 4 +- src/backend/tests/unit/io/test_io_schema.py | 2 +- .../tests/unit/io/test_table_schema.py | 2 +- .../unit/services/variable/test_service.py | 40 +++------- src/backend/tests/unit/test_api_key.py | 6 +- .../tests/unit/test_custom_component.py | 78 +++++-------------- .../unit/test_custom_component_with_client.py | 5 +- src/backend/tests/unit/test_database.py | 22 +++--- src/backend/tests/unit/test_endpoints.py | 4 +- src/backend/tests/unit/test_files.py | 16 +++- src/backend/tests/unit/test_initial_setup.py | 26 ++++--- .../tests/unit/test_kubernetes_secrets.py | 4 +- src/backend/tests/unit/test_messages.py | 6 +- .../tests/unit/test_messages_endpoints.py | 6 +- src/backend/tests/unit/test_process.py | 20 +++-- src/backend/tests/unit/test_schema.py | 38 ++++----- .../tests/unit/test_setup_superuser.py | 6 +- src/backend/tests/unit/test_user.py | 19 +++-- src/backend/tests/unit/test_webhook.py | 2 +- 67 files changed, 424 insertions(+), 364 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f04f263db211..43471e982dd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -216,6 +216,43 @@ directory = "coverage" exclude = ["src/backend/langflow/alembic/*"] line-length = 120 +[tool.ruff.lint] +pydocstyle.convention = "google" +select = ["ALL"] +ignore = [ + "C90", # McCabe complexity + "CPY", # Missing copyright + "COM812", # Messes with the formatter + "ERA", # Eradicate commented-out code + "FIX002", # Line contains TODO + "ISC001", # Messes with the formatter + "PERF203", # Rarely useful + "PLR09", # Too many something (arg, statements, etc) + "RUF012", # Pydantic models are currently not well detected. See https://github.com/astral-sh/ruff/issues/13630 + "TD002", # Missing author in TODO + "TD003", # Missing issue link in TODO + "TRY301", # A bit too harsh (Abstract `raise` to an inner function) + + # Rules that are TODOs + "ANN", +] + +# Preview rules that are not yet activated +external = ["RUF027"] + +[tool.ruff.lint.per-file-ignores] +"scripts/*" = [ + "D1", + "INP", + "T201", +] +"src/backend/tests/*" = [ + "D1", + "PLR2004", + "S101", + "SLF001", +] + [tool.mypy] plugins = ["pydantic.mypy"] follow_imports = "skip" diff --git a/scripts/ci/pypi_nightly_tag.py b/scripts/ci/pypi_nightly_tag.py index 4ac02fc2cff0..d15dffc6c160 100755 --- a/scripts/ci/pypi_nightly_tag.py +++ b/scripts/ci/pypi_nightly_tag.py @@ -12,8 +12,10 @@ PYPI_LANGFLOW_BASE_URL = "https://pypi.org/pypi/langflow-base/json" PYPI_LANGFLOW_BASE_NIGHTLY_URL = "https://pypi.org/pypi/langflow-base-nightly/json" +ARGUMENT_NUMBER = 2 -def get_latest_published_version(build_type: str, is_nightly: bool) -> Version: + +def get_latest_published_version(build_type: str, *, is_nightly: bool) -> Version: import requests url = "" @@ -25,12 +27,12 @@ def get_latest_published_version(build_type: str, is_nightly: bool) -> Version: msg = f"Invalid build type: {build_type}" raise ValueError(msg) - res = requests.get(url) + res = requests.get(url, timeout=10) try: version_str = res.json()["info"]["version"] except Exception as e: msg = "Got unexpected response from PyPI" - raise RuntimeError(msg, e) + raise RuntimeError(msg) from e return Version(version_str) @@ -74,9 +76,9 @@ def create_tag(build_type: str): if __name__ == "__main__": - if len(sys.argv) != 2: + if len(sys.argv) != ARGUMENT_NUMBER: msg = "Specify base or main" - raise Exception(msg) + raise ValueError(msg) build_type = sys.argv[1] tag = create_tag(build_type) diff --git a/scripts/ci/update_lf_base_dependency.py b/scripts/ci/update_lf_base_dependency.py index 1de9dc201362..e3e5d1aabd20 100755 --- a/scripts/ci/update_lf_base_dependency.py +++ b/scripts/ci/update_lf_base_dependency.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python + import re import sys from pathlib import Path @@ -5,6 +7,7 @@ import packaging.version BASE_DIR = Path(__file__).parent.parent.parent +ARGUMENT_NUMBER = 2 def update_base_dep(pyproject_path: str, new_version: str) -> None: @@ -18,7 +21,7 @@ def update_base_dep(pyproject_path: str, new_version: str) -> None: pattern = re.compile(r'langflow-base = \{ path = "\./src/backend/base", develop = true \}') if not pattern.search(content): msg = f'langflow-base poetry dependency not found in "{filepath}"' - raise Exception(msg) + raise ValueError(msg) content = pattern.sub(replacement, content) filepath.write_text(content, encoding="utf-8") @@ -28,16 +31,13 @@ def verify_pep440(version): https://github.com/pypa/packaging/blob/16.7/packaging/version.py#L191 """ - try: - return packaging.version.Version(version) - except packaging.version.InvalidVersion: - raise + return packaging.version.Version(version) def main() -> None: - if len(sys.argv) != 2: + if len(sys.argv) != ARGUMENT_NUMBER: msg = "New version not specified" - raise Exception(msg) + raise ValueError(msg) base_version = sys.argv[1] # Strip "v" prefix from version if present diff --git a/scripts/ci/update_pyproject_name.py b/scripts/ci/update_pyproject_name.py index a11fc4087c95..38511bf4c14f 100755 --- a/scripts/ci/update_pyproject_name.py +++ b/scripts/ci/update_pyproject_name.py @@ -1,8 +1,11 @@ +#!/usr/bin/env python + import re import sys from pathlib import Path BASE_DIR = Path(__file__).parent.parent.parent +ARGUMENT_NUMBER = 3 def update_pyproject_name(pyproject_path: str, new_project_name: str) -> None: @@ -15,7 +18,7 @@ def update_pyproject_name(pyproject_path: str, new_project_name: str) -> None: if not pattern.search(content): msg = f'Project name not found in "{filepath}"' - raise Exception(msg) + raise ValueError(msg) content = pattern.sub(new_project_name, content) filepath.write_text(content, encoding="utf-8") @@ -39,15 +42,15 @@ def update_uv_dep(pyproject_path: str, new_project_name: str) -> None: # Updates the dependency name for uv if not pattern.search(content): msg = f"{replacement} uv dependency not found in {filepath}" - raise Exception(msg) + raise ValueError(msg) content = pattern.sub(replacement, content) filepath.write_text(content, encoding="utf-8") def main() -> None: - if len(sys.argv) != 3: + if len(sys.argv) != ARGUMENT_NUMBER: msg = "Must specify project name and build type, e.g. langflow-nightly base" - raise Exception(msg) + raise ValueError(msg) new_project_name = sys.argv[1] build_type = sys.argv[2] diff --git a/scripts/ci/update_pyproject_version.py b/scripts/ci/update_pyproject_version.py index 2b7d6021ac5b..79cbbdc6c35e 100755 --- a/scripts/ci/update_pyproject_version.py +++ b/scripts/ci/update_pyproject_version.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python + import re import sys from pathlib import Path @@ -5,6 +7,7 @@ import packaging.version BASE_DIR = Path(__file__).parent.parent.parent +ARGUMENT_NUMBER = 3 def update_pyproject_version(pyproject_path: str, new_version: str) -> None: @@ -17,7 +20,7 @@ def update_pyproject_version(pyproject_path: str, new_version: str) -> None: if not pattern.search(content): msg = f'Project version not found in "{filepath}"' - raise Exception(msg) + raise ValueError(msg) content = pattern.sub(new_version, content) @@ -29,16 +32,13 @@ def verify_pep440(version): https://github.com/pypa/packaging/blob/16.7/packaging/version.py#L191 """ - try: - return packaging.version.Version(version) - except packaging.version.InvalidVersion: - raise + return packaging.version.Version(version) def main() -> None: - if len(sys.argv) != 3: + if len(sys.argv) != ARGUMENT_NUMBER: msg = "New version not specified" - raise Exception(msg) + raise ValueError(msg) new_version = sys.argv[1] # Strip "v" prefix from version if present diff --git a/scripts/ci/update_uv_dependency.py b/scripts/ci/update_uv_dependency.py index 6452d30a1a15..c4ac2e809109 100755 --- a/scripts/ci/update_uv_dependency.py +++ b/scripts/ci/update_uv_dependency.py @@ -1,8 +1,11 @@ +#!/usr/bin/env python + import re import sys from pathlib import Path BASE_DIR = Path(__file__).parent.parent.parent +ARGUMENT_NUMBER = 2 def update_uv_dep(base_version: str) -> None: @@ -19,7 +22,7 @@ def update_uv_dep(base_version: str) -> None: # Check if the pattern is found if not pattern.search(content): msg = f"{pattern} UV dependency not found in {pyproject_path}" - raise Exception(msg) + raise ValueError(msg) # Replace the matched pattern with the new one content = pattern.sub(replacement, content) @@ -29,9 +32,9 @@ def update_uv_dep(base_version: str) -> None: def main() -> None: - if len(sys.argv) != 2: + if len(sys.argv) != ARGUMENT_NUMBER: msg = "specify base version" - raise Exception(msg) + raise ValueError(msg) base_version = sys.argv[1] base_version = base_version.lstrip("v") update_uv_dep(base_version) diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index fa9a0bb874ab..7f0c675e4887 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -12,7 +12,7 @@ from langflow.base.tools.constants import TOOL_OUTPUT_NAME from langflow.custom.tree_visitor import RequiredInputsVisitor -from langflow.field_typing import Tool # noqa: TCH001 Needed by add_toolkit_output +from langflow.field_typing import Tool # noqa: TCH001 Needed by _add_toolkit_output from langflow.graph.state.model import create_state_model from langflow.helpers.custom import format_type from langflow.schema.artifact import get_artifact_type, post_process_raw diff --git a/src/backend/langflow/version/__init__.py b/src/backend/langflow/version/__init__.py index e69de29bb2d1..7980c0451962 100644 --- a/src/backend/langflow/version/__init__.py +++ b/src/backend/langflow/version/__init__.py @@ -0,0 +1 @@ +"""Version package.""" diff --git a/src/backend/langflow/version/version.py b/src/backend/langflow/version/version.py index 2cd42e3977f1..377539fb6bcc 100644 --- a/src/backend/langflow/version/version.py +++ b/src/backend/langflow/version/version.py @@ -1,8 +1,11 @@ +"""Module for package versioning.""" + import contextlib def get_version() -> str: """Retrieves the version of the package from a possible list of package names. + This accounts for after package names are updated for -nightly builds. Returns: @@ -32,7 +35,9 @@ def get_version() -> str: def is_pre_release(v: str) -> bool: - """Returns a boolean indicating whether the version is a pre-release version, + """Returns a boolean indicating whether the version is a pre-release version. + + Returns a boolean indicating whether the version is a pre-release version, as per the definition of a pre-release segment from PEP 440. """ return any(label in v for label in ["a", "b", "rc"]) diff --git a/src/backend/tests/conftest.py b/src/backend/tests/conftest.py index 9857876e2530..5d80f9cc5862 100644 --- a/src/backend/tests/conftest.py +++ b/src/backend/tests/conftest.py @@ -28,7 +28,6 @@ from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service from loguru import logger -from pytest import LogCaptureFixture from sqlmodel import Session, SQLModel, create_engine, select from sqlmodel.pool import StaticPool from typer.testing import CliRunner @@ -102,7 +101,7 @@ def _delete_transactions_and_vertex_builds(session, user: User): @pytest.fixture -def caplog(caplog: LogCaptureFixture): +def caplog(caplog: pytest.LogCaptureFixture): handler_id = logger.add( caplog.handler, format="{message}", @@ -144,7 +143,7 @@ def load_flows_dir(): @pytest.fixture(name="distributed_env") -def setup_env(monkeypatch): +def _setup_env(monkeypatch): monkeypatch.setenv("LANGFLOW_CACHE_TYPE", "redis") monkeypatch.setenv("LANGFLOW_REDIS_HOST", "result_backend") monkeypatch.setenv("LANGFLOW_REDIS_PORT", "6379") @@ -158,7 +157,11 @@ def setup_env(monkeypatch): @pytest.fixture(name="distributed_client") -def distributed_client_fixture(session: Session, monkeypatch, distributed_env): +def distributed_client_fixture( + session: Session, # noqa: ARG001 + monkeypatch, + distributed_env, # noqa: ARG001 +): # Here we load the .env from ../deploy/.env from langflow.core import celery_app @@ -273,7 +276,12 @@ def json_memory_chatbot_no_llm(): @pytest.fixture(name="client") -async def client_fixture(session: Session, monkeypatch, request, load_flows_dir): +async def client_fixture( + session: Session, # noqa: ARG001 + monkeypatch, + request, + load_flows_dir, +): # Set the database url to a test database if "noclient" in request.keywords: yield @@ -296,9 +304,11 @@ async def client_fixture(session: Session, monkeypatch, request, load_flows_dir) db_service.database_url = f"sqlite:///{db_path}" db_service.reload_engine() # app.dependency_overrides[get_session] = get_session_override - async with LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager: - async with AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client: - yield client + async with ( + LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager, + AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client, + ): + yield client # app.dependency_overrides.clear() monkeypatch.undo() # clear the temp db @@ -308,7 +318,7 @@ async def client_fixture(session: Session, monkeypatch, request, load_flows_dir) # create a fixture for session_getter above @pytest.fixture(name="session_getter") -def session_getter_fixture(client): +def session_getter_fixture(client): # noqa: ARG001 @contextmanager def blank_session_getter(db_service: "DatabaseService"): with Session(db_service.engine) as session: @@ -326,7 +336,7 @@ def runner(): async def test_user(client): user_data = UserCreate( username="testuser", - password="testpassword", + password="testpassword", # noqa: S106 ) response = await client.post("api/v1/users/", json=user_data.model_dump()) assert response.status_code == 201 @@ -337,7 +347,7 @@ async def test_user(client): @pytest.fixture -def active_user(client): +def active_user(client): # noqa: ARG001 db_manager = get_db_service() with db_manager.with_session() as session: user = User( @@ -375,7 +385,11 @@ async def logged_in_headers(client, active_user): @pytest.fixture -def flow(client, json_flow: str, active_user): +def flow( + client, # noqa: ARG001 + json_flow: str, + active_user, +): from langflow.services.database.models.flow.model import FlowCreate loaded_json = json.loads(json_flow) diff --git a/src/backend/tests/data/component.py b/src/backend/tests/data/component.py index 5de63fcd5f40..e16ea3b4d2d5 100644 --- a/src/backend/tests/data/component.py +++ b/src/backend/tests/data/component.py @@ -7,7 +7,7 @@ class TestComponent(CustomComponent): def refresh_values(self): # This is a function that will be called every time the component is updated # and should return a list of random strings - return [f"Random {random.randint(1, 100)}" for _ in range(5)] + return [f"Random {random.randint(1, 100)}" for _ in range(5)] # noqa: S311 def build_config(self): return {"param": {"display_name": "Param", "options": self.refresh_values}} diff --git a/src/backend/tests/data/component_nested_call.py b/src/backend/tests/data/component_nested_call.py index 5dd61c2bab33..526d7cc88f9f 100644 --- a/src/backend/tests/data/component_nested_call.py +++ b/src/backend/tests/data/component_nested_call.py @@ -16,7 +16,7 @@ class MultipleOutputsComponent(Component): ] def certain_output(self) -> int: - return randint(0, self.number) + return randint(0, self.number) # noqa: S311 def other_output(self) -> int: return self.certain_output() diff --git a/src/backend/tests/data/component_with_templatefield.py b/src/backend/tests/data/component_with_templatefield.py index bc79c80d2ef1..cde77f717557 100644 --- a/src/backend/tests/data/component_with_templatefield.py +++ b/src/backend/tests/data/component_with_templatefield.py @@ -8,7 +8,7 @@ class TestComponent(CustomComponent): def refresh_values(self): # This is a function that will be called every time the component is updated # and should return a list of random strings - return [f"Random {random.randint(1, 100)}" for _ in range(5)] + return [f"Random {random.randint(1, 100)}" for _ in range(5)] # noqa: S311 def build_config(self): return {"param": Input(display_name="Param", options=self.refresh_values)} diff --git a/src/backend/tests/integration/backward_compatibility/test_starter_projects.py b/src/backend/tests/integration/backward_compatibility/test_starter_projects.py index f175e17eb630..d2319c56d704 100644 --- a/src/backend/tests/integration/backward_compatibility/test_starter_projects.py +++ b/src/backend/tests/integration/backward_compatibility/test_starter_projects.py @@ -1,5 +1,6 @@ import pytest from langflow.schema.message import Message + from tests.api_keys import get_openai_api_key from tests.integration.utils import download_flow_from_github, run_json_flow diff --git a/src/backend/tests/integration/components/assistants/test_assistants_components.py b/src/backend/tests/integration/components/assistants/test_assistants_components.py index a7b9f3b8a04b..019711b35148 100644 --- a/src/backend/tests/integration/components/assistants/test_assistants_components.py +++ b/src/backend/tests/integration/components/assistants/test_assistants_components.py @@ -1,4 +1,5 @@ import pytest + from tests.integration.utils import run_single_component diff --git a/src/backend/tests/integration/components/astra/test_astra_component.py b/src/backend/tests/integration/components/astra/test_astra_component.py index 06da3ec54cd8..fed8d1866132 100644 --- a/src/backend/tests/integration/components/astra/test_astra_component.py +++ b/src/backend/tests/integration/components/astra/test_astra_component.py @@ -6,6 +6,7 @@ from langflow.components.embeddings import OpenAIEmbeddingsComponent from langflow.components.vectorstores import AstraVectorStoreComponent from langflow.schema.data import Data + from tests.api_keys import get_astradb_api_endpoint, get_astradb_application_token, get_openai_api_key from tests.integration.components.mock_components import TextToData from tests.integration.utils import ComponentInputHandle, run_single_component @@ -27,7 +28,7 @@ @pytest.fixture -def astradb_client(request): +def astradb_client(): client = AstraDB(api_endpoint=get_astradb_api_endpoint(), token=get_astradb_application_token()) yield client for collection in ALL_COLLECTIONS: diff --git a/src/backend/tests/integration/components/helpers/test_parse_json_data.py b/src/backend/tests/integration/components/helpers/test_parse_json_data.py index 5c9931c97f96..48d5ded926f1 100644 --- a/src/backend/tests/integration/components/helpers/test_parse_json_data.py +++ b/src/backend/tests/integration/components/helpers/test_parse_json_data.py @@ -2,6 +2,7 @@ from langflow.components.helpers.ParseJSONData import ParseJSONDataComponent from langflow.components.inputs import ChatInput from langflow.schema import Data + from tests.integration.components.mock_components import TextToData from tests.integration.utils import ComponentInputHandle, run_single_component diff --git a/src/backend/tests/integration/components/inputs/test_chat_input.py b/src/backend/tests/integration/components/inputs/test_chat_input.py index f308685e6d8a..dd7840eefaf3 100644 --- a/src/backend/tests/integration/components/inputs/test_chat_input.py +++ b/src/backend/tests/integration/components/inputs/test_chat_input.py @@ -2,6 +2,7 @@ from langflow.components.inputs import ChatInput from langflow.memory import get_messages from langflow.schema.message import Message + from tests.integration.utils import run_single_component diff --git a/src/backend/tests/integration/components/inputs/test_text_input.py b/src/backend/tests/integration/components/inputs/test_text_input.py index c08d65067bd6..c4169c987043 100644 --- a/src/backend/tests/integration/components/inputs/test_text_input.py +++ b/src/backend/tests/integration/components/inputs/test_text_input.py @@ -1,6 +1,7 @@ import pytest from langflow.components.inputs import TextInputComponent from langflow.schema.message import Message + from tests.integration.utils import run_single_component diff --git a/src/backend/tests/integration/components/output_parsers/test_output_parser.py b/src/backend/tests/integration/components/output_parsers/test_output_parser.py index 01fa660e3259..74c7c2e03efe 100644 --- a/src/backend/tests/integration/components/output_parsers/test_output_parser.py +++ b/src/backend/tests/integration/components/output_parsers/test_output_parser.py @@ -4,6 +4,7 @@ from langflow.components.models.OpenAIModel import OpenAIModelComponent from langflow.components.output_parsers.OutputParser import OutputParserComponent from langflow.components.prompts.Prompt import PromptComponent + from tests.integration.utils import ComponentInputHandle, run_single_component diff --git a/src/backend/tests/integration/components/outputs/test_chat_output.py b/src/backend/tests/integration/components/outputs/test_chat_output.py index d5ca1de58508..27b4e42171ec 100644 --- a/src/backend/tests/integration/components/outputs/test_chat_output.py +++ b/src/backend/tests/integration/components/outputs/test_chat_output.py @@ -2,6 +2,7 @@ from langflow.components.outputs import ChatOutput from langflow.memory import get_messages from langflow.schema.message import Message + from tests.integration.utils import run_single_component diff --git a/src/backend/tests/integration/components/outputs/test_text_output.py b/src/backend/tests/integration/components/outputs/test_text_output.py index 87e027d33552..22076462e903 100644 --- a/src/backend/tests/integration/components/outputs/test_text_output.py +++ b/src/backend/tests/integration/components/outputs/test_text_output.py @@ -1,6 +1,7 @@ import pytest from langflow.components.outputs import TextOutputComponent from langflow.schema.message import Message + from tests.integration.utils import run_single_component diff --git a/src/backend/tests/integration/components/prompts/test_prompt.py b/src/backend/tests/integration/components/prompts/test_prompt.py index 744653269a17..1f52fd3b7079 100644 --- a/src/backend/tests/integration/components/prompts/test_prompt.py +++ b/src/backend/tests/integration/components/prompts/test_prompt.py @@ -1,6 +1,7 @@ import pytest from langflow.components.prompts import PromptComponent from langflow.schema.message import Message + from tests.integration.utils import run_single_component diff --git a/src/backend/tests/integration/flows/test_basic_prompting.py b/src/backend/tests/integration/flows/test_basic_prompting.py index 47298099919b..46dac6dec67d 100644 --- a/src/backend/tests/integration/flows/test_basic_prompting.py +++ b/src/backend/tests/integration/flows/test_basic_prompting.py @@ -4,18 +4,19 @@ from langflow.components.prompts import PromptComponent from langflow.graph import Graph from langflow.schema.message import Message + from tests.integration.utils import run_flow @pytest.mark.asyncio async def test_simple_no_llm(): graph = Graph() - input = graph.add_component(ChatInput()) - output = graph.add_component(ChatOutput()) + flow_input = graph.add_component(ChatInput()) + flow_output = graph.add_component(ChatOutput()) component = PromptComponent(template="This is the message: {var1}", var1="") prompt = graph.add_component(component) - graph.add_component_edge(input, ("message", "var1"), prompt) - graph.add_component_edge(prompt, ("prompt", "input_value"), output) + graph.add_component_edge(flow_input, ("message", "var1"), prompt) + graph.add_component_edge(prompt, ("prompt", "input_value"), flow_output) outputs = await run_flow(graph, run_input="hello!") assert isinstance(outputs["message"], Message) assert outputs["message"].text == "This is the message: hello!" diff --git a/src/backend/tests/integration/utils.py b/src/backend/tests/integration/utils.py index 7bc88251f3ba..d8d2a1141d9d 100644 --- a/src/backend/tests/integration/utils.py +++ b/src/backend/tests/integration/utils.py @@ -12,26 +12,26 @@ from langflow.processing.process import run_graph_internal -def check_env_vars(*vars): +def check_env_vars(*env_vars): """Check if all specified environment variables are set. Args: - *vars (str): The environment variables to check. + *env_vars (str): The environment variables to check. Returns: bool: True if all environment variables are set, False otherwise. """ - return all(os.getenv(var) for var in vars) + return all(os.getenv(var) for var in env_vars) def valid_nvidia_vectorize_region(api_endpoint: str) -> bool: """Check if the specified region is valid. Args: - region (str): The region to check. + api_endpoint: The API endpoint to check. Returns: - bool: True if the region is contains hosted nvidia models, False otherwise. + True if the region contains hosted nvidia models, False otherwise. """ parsed_endpoint = parse_api_endpoint(api_endpoint) if not parsed_endpoint: @@ -63,12 +63,12 @@ class JSONFlow: json: dict def get_components_by_type(self, component_type): - result = [] - for node in self.json["data"]["nodes"]: - if node["data"]["type"] == component_type: - result.append(node["id"]) + result = [node["id"] for node in self.json["data"]["nodes"] if node["data"]["type"] == component_type] if not result: - msg = f"Component of type {component_type} not found, available types: {', '.join({node['data']['type'] for node in self.json['data']['nodes']})}" + msg = ( + f"Component of type {component_type} not found, " + f"available types: {', '.join({node['data']['type'] for node in self.json['data']['nodes']})}" + ) raise ValueError(msg) return result @@ -97,7 +97,8 @@ def set_value(self, component_id, key, value): def download_flow_from_github(name: str, version: str) -> JSONFlow: response = requests.get( - f"https://raw.githubusercontent.com/langflow-ai/langflow/v{version}/src/backend/base/langflow/initial_setup/starter_projects/{name}.json" + f"https://raw.githubusercontent.com/langflow-ai/langflow/v{version}/src/backend/base/langflow/initial_setup/starter_projects/{name}.json", + timeout=10, ) response.raise_for_status() as_json = response.json() @@ -151,7 +152,7 @@ def _add_component(clazz: type, inputs: dict | None = None) -> str: raw_inputs[key] = value if isinstance(value, Component): msg = "Component inputs must be wrapped in ComponentInputHandle" - raise ValueError(msg) + raise TypeError(msg) component = clazz(**raw_inputs, _user_id=user_id) component_id = graph.add_component(component) if inputs: diff --git a/src/backend/tests/locust/locustfile.py b/src/backend/tests/locust/locustfile.py index c85e62c1a65d..f7857e619e99 100644 --- a/src/backend/tests/locust/locustfile.py +++ b/src/backend/tests/locust/locustfile.py @@ -60,7 +60,7 @@ def process(self, name, flow_id, payload): @task def send_name_and_check(self): - name = random.choice(self.names) + name = random.choice(self.names) # noqa: S311 payload1 = { "inputs": {"text": f"Hello, My name is {name}"}, diff --git a/src/backend/tests/unit/api/test_api_utils.py b/src/backend/tests/unit/api/test_api_utils.py index d992f10fe591..f87befdaf4f6 100644 --- a/src/backend/tests/unit/api/test_api_utils.py +++ b/src/backend/tests/unit/api/test_api_utils.py @@ -17,7 +17,10 @@ def test_get_suggestion_message(): # Test case 3: Multiple outdated components outdated_components = ["component1", "component2", "component3"] - expected_message = "The flow contains 3 outdated components. We recommend updating the following components: component1, component2, component3." + expected_message = ( + "The flow contains 3 outdated components. " + "We recommend updating the following components: component1, component2, component3." + ) assert get_suggestion_message(outdated_components) == expected_message diff --git a/src/backend/tests/unit/api/v1/test_variable.py b/src/backend/tests/unit/api/v1/test_variable.py index 26f1b142196c..6c834542a4b7 100644 --- a/src/backend/tests/unit/api/v1/test_variable.py +++ b/src/backend/tests/unit/api/v1/test_variable.py @@ -75,7 +75,7 @@ async def test_create_variable__variable_value_cannot_be_empty(client: AsyncClie @pytest.mark.usefixtures("active_user") -async def test_create_variable__HTTPException(client: AsyncClient, body, logged_in_headers): +async def test_create_variable__httpexception(client: AsyncClient, body, logged_in_headers): status_code = 418 generic_message = "I'm a teapot" @@ -89,7 +89,7 @@ async def test_create_variable__HTTPException(client: AsyncClient, body, logged_ @pytest.mark.usefixtures("active_user") -async def test_create_variable__Exception(client: AsyncClient, body, logged_in_headers): +async def test_create_variable__exception(client: AsyncClient, body, logged_in_headers): generic_message = "Generic error message" with mock.patch("langflow.services.auth.utils.encrypt_api_key") as m: @@ -133,16 +133,10 @@ async def test_read_variables__empty(client: AsyncClient, logged_in_headers): async def test_read_variables__(client: AsyncClient, logged_in_headers): generic_message = "Generic error message" - with pytest.raises(Exception) as exc, mock.patch("sqlmodel.Session.exec") as m: + with mock.patch("sqlmodel.Session.exec") as m: m.side_effect = Exception(generic_message) - - response = await client.get("api/v1/variables/", headers=logged_in_headers) - result = response.json() - - assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - assert generic_message in result["detail"] - - assert generic_message in str(exc.value) + with pytest.raises(Exception, match=generic_message): + await client.get("api/v1/variables/", headers=logged_in_headers) @pytest.mark.usefixtures("active_user") @@ -165,7 +159,7 @@ async def test_update_variable(client: AsyncClient, body, logged_in_headers): @pytest.mark.usefixtures("active_user") -async def test_update_variable__Exception(client: AsyncClient, body, logged_in_headers): +async def test_update_variable__exception(client: AsyncClient, body, logged_in_headers): wrong_id = uuid4() body["id"] = str(wrong_id) @@ -186,7 +180,7 @@ async def test_delete_variable(client: AsyncClient, body, logged_in_headers): @pytest.mark.usefixtures("active_user") -async def test_delete_variable__Exception(client: AsyncClient, logged_in_headers): +async def test_delete_variable__exception(client: AsyncClient, logged_in_headers): wrong_id = uuid4() response = await client.delete(f"api/v1/variables/{wrong_id}", headers=logged_in_headers) diff --git a/src/backend/tests/unit/base/load/test_load.py b/src/backend/tests/unit/base/load/test_load.py index f59ddfd1606c..4c7927a1044f 100644 --- a/src/backend/tests/unit/base/load/test_load.py +++ b/src/backend/tests/unit/base/load/test_load.py @@ -26,4 +26,5 @@ def test_run_flow_from_json_params(): params = func_spec.args + func_spec.kwonlyargs assert expected_params.issubset(params), "Not all expected parameters are present in run_flow_from_json" - # TODO: Add tests by loading a flow and running it need to text with fake llm and check if it returns the correct output + # TODO: Add tests by loading a flow and running it need to text with fake llm and check if it returns the + # correct output diff --git a/src/backend/tests/unit/base/tools/test_component_toolkit.py b/src/backend/tests/unit/base/tools/test_component_toolkit.py index 85d0d73c5165..07b9694fb4d5 100644 --- a/src/backend/tests/unit/base/tools/test_component_toolkit.py +++ b/src/backend/tests/unit/base/tools/test_component_toolkit.py @@ -12,7 +12,7 @@ @pytest.fixture -def add_toolkit_output(): +def _add_toolkit_output(): FEATURE_FLAGS.add_toolkit_output = True yield FEATURE_FLAGS.add_toolkit_output = False @@ -81,7 +81,7 @@ def test_component_tool(): @pytest.mark.api_key_required -@pytest.mark.usefixtures("add_toolkit_output", "client") +@pytest.mark.usefixtures("_add_toolkit_output", "client") def test_component_tool_with_api_key(): chat_output = ChatOutput() openai_llm = OpenAIModelComponent() diff --git a/src/backend/tests/unit/components/helpers/test_structured_output_component.py b/src/backend/tests/unit/components/helpers/test_structured_output_component.py index 16f4879355c4..bce79b4b5fef 100644 --- a/src/backend/tests/unit/components/helpers/test_structured_output_component.py +++ b/src/backend/tests/unit/components/helpers/test_structured_output_component.py @@ -1,34 +1,60 @@ from unittest.mock import MagicMock, patch import pytest +from langchain_core.language_models import BaseLanguageModel from langflow.components.helpers.structured_output import StructuredOutputComponent from langflow.schema.data import Data from pydantic import BaseModel - - -@pytest.fixture -def client(): - pass +from typing_extensions import override class TestStructuredOutputComponent: - # Ensure that the structured output is successfully generated with the correct BaseModel instance returned by the mock function + # Ensure that the structured output is successfully generated with the correct BaseModel instance returned by + # the mock function def test_successful_structured_output_generation_with_patch_with_config(self): from unittest.mock import patch - class MockLanguageModel: - def with_structured_output(self, schema): + class MockLanguageModel(BaseLanguageModel): + @override + def with_structured_output(self, *args, **kwargs): return self - def with_config(self, config): + @override + def with_config(self, *args, **kwargs): return self - def invoke(self, inputs): + @override + def invoke(self, *args, **kwargs): return self - def mock_get_chat_result(runnable, input_value, config): + @override + def generate_prompt(self, *args, **kwargs): + raise NotImplementedError + + @override + async def agenerate_prompt(self, *args, **kwargs): + raise NotImplementedError + + @override + def predict(self, *args, **kwargs): + raise NotImplementedError + + @override + def predict_messages(self, *args, **kwargs): + raise NotImplementedError + + @override + async def apredict(self, *args, **kwargs): + raise NotImplementedError + + @override + async def apredict_messages(self, *args, **kwargs): + raise NotImplementedError + + def mock_get_chat_result(runnable, input_value, config): # noqa: ARG001 class MockBaseModel(BaseModel): - def model_dump(self): + @override + def model_dump(self, **kwargs): return {"field": "value"} return MockBaseModel() diff --git a/src/backend/tests/unit/components/models/test_ChatOllama_component.py b/src/backend/tests/unit/components/models/test_ChatOllama_component.py index 6c5060b93690..2970ba4ff937 100644 --- a/src/backend/tests/unit/components/models/test_ChatOllama_component.py +++ b/src/backend/tests/unit/components/models/test_ChatOllama_component.py @@ -112,7 +112,7 @@ def test_update_build_config_keep_alive(component): "langchain_community.chat_models.ChatOllama", return_value=ChatOllama(base_url="http://localhost:11434", model="llama3.1"), ) -def test_build_model(_mock_chat_ollama, component): +def test_build_model(_mock_chat_ollama, component): # noqa: PT019 component.base_url = "http://localhost:11434" component.model_name = "llama3.1" component.mirostat = "Mirostat 2.0" diff --git a/src/backend/tests/unit/components/prompts/test_prompt_component.py b/src/backend/tests/unit/components/prompts/test_prompt_component.py index 0efbd93e7655..2c1f5cad2cd9 100644 --- a/src/backend/tests/unit/components/prompts/test_prompt_component.py +++ b/src/backend/tests/unit/components/prompts/test_prompt_component.py @@ -1,4 +1,4 @@ -from langflow.components.prompts.Prompt import PromptComponent # type: ignore +from langflow.components.prompts.Prompt import PromptComponent class TestPromptComponent: diff --git a/src/backend/tests/unit/components/prototypes/test_create_data_component.py b/src/backend/tests/unit/components/prototypes/test_create_data_component.py index d5dc01a5c8e3..f97b0794f5fd 100644 --- a/src/backend/tests/unit/components/prototypes/test_create_data_component.py +++ b/src/backend/tests/unit/components/prototypes/test_create_data_component.py @@ -109,9 +109,5 @@ def test_validate_text_key_invalid(create_data_component): create_data_component.text_key = "invalid_key" # Act & Assert - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValueError, match="Text Key: 'invalid_key' not found in the Data keys: 'key1, key2'"): create_data_component.validate_text_key() - - # Check for the exact error message - expected_error_message = f"Text Key: '{create_data_component.text_key}' not found in the Data keys: '{', '.join(create_data_component.get_data().keys())}'" - assert str(exc_info.value) == expected_error_message diff --git a/src/backend/tests/unit/components/prototypes/test_update_data_component.py b/src/backend/tests/unit/components/prototypes/test_update_data_component.py index 36f44edd54bd..d60bb4c9c5cd 100644 --- a/src/backend/tests/unit/components/prototypes/test_update_data_component.py +++ b/src/backend/tests/unit/components/prototypes/test_update_data_component.py @@ -96,10 +96,5 @@ def test_validate_text_key_invalid(update_data_component): data = Data(data={"key1": "value1", "key2": "value2"}, text_key="key1") update_data_component.text_key = "invalid_key" - with pytest.raises(ValueError) as exc_info: + with pytest.raises(ValueError, match="Text Key: invalid_key not found in the Data keys: key1,key2"): update_data_component.validate_text_key(data) - - expected_error_message = ( - f"Text Key: {update_data_component.text_key} not found in the Data keys: {','.join(data.data.keys())}" - ) - assert str(exc_info.value) == expected_error_message diff --git a/src/backend/tests/unit/custom/custom_component/test_component.py b/src/backend/tests/unit/custom/custom_component/test_component.py index ce567d5df7dd..2dfa9a0220d2 100644 --- a/src/backend/tests/unit/custom/custom_component/test_component.py +++ b/src/backend/tests/unit/custom/custom_component/test_component.py @@ -11,7 +11,7 @@ def test_set_invalid_output(): chatinput = ChatInput() chatoutput = ChatOutput() - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Method build_config is not a valid output of ChatInput"): chatoutput.set(input_value=chatinput.build_config) diff --git a/src/backend/tests/unit/events/test_event_manager.py b/src/backend/tests/unit/events/test_event_manager.py index 09bcc4eecdaa..b02e29d2c316 100644 --- a/src/backend/tests/unit/events/test_event_manager.py +++ b/src/backend/tests/unit/events/test_event_manager.py @@ -90,9 +90,9 @@ def mock_callback(event_type, data): queue = asyncio.Queue() manager = EventManager(queue) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Event name cannot be empty"): manager.register_event("", "test_type", mock_callback) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Event name must start with 'on_'"): manager.register_event("invalid_name", "test_type", mock_callback) # Sending an event with complex data and verifying successful event transmission diff --git a/src/backend/tests/unit/exceptions/test_api.py b/src/backend/tests/unit/exceptions/test_api.py index 38cf5f57f36a..9934eb7ce46b 100644 --- a/src/backend/tests/unit/exceptions/test_api.py +++ b/src/backend/tests/unit/exceptions/test_api.py @@ -16,28 +16,32 @@ def test_api_exception(): } # Expected result - with patch( - "langflow.services.database.models.flow.utils.get_outdated_components", return_value=mock_outdated_components + with ( + patch( + "langflow.services.database.models.flow.utils.get_outdated_components", + return_value=mock_outdated_components, + ), + patch("langflow.api.utils.get_suggestion_message", return_value=mock_suggestion_message), + patch( + "langflow.services.database.models.flow.utils.get_components_versions", + return_value=mock_component_versions, + ), ): - with patch("langflow.api.utils.get_suggestion_message", return_value=mock_suggestion_message): - with patch( - "langflow.services.database.models.flow.utils.get_components_versions", - return_value=mock_component_versions, - ): - # Create an APIException instance - api_exception = APIException(mock_exception, mock_flow) - - # Expected body - expected_body = ExceptionBody( - message="Test exception", - suggestion="The flow contains 2 outdated components. We recommend updating the following components: component1, component2.", - ) - - # Assert the status code - assert api_exception.status_code == 500 - - # Assert the detail - assert api_exception.detail == expected_body.model_dump_json() + # Create an APIException instance + api_exception = APIException(mock_exception, mock_flow) + + # Expected body + expected_body = ExceptionBody( + message="Test exception", + suggestion="The flow contains 2 outdated components. " + "We recommend updating the following components: component1, component2.", + ) + + # Assert the status code + assert api_exception.status_code == 500 + + # Assert the detail + assert api_exception.detail == expected_body.model_dump_json() def test_api_exception_no_flow(): diff --git a/src/backend/tests/unit/graph/graph/state/test_state_model.py b/src/backend/tests/unit/graph/graph/state/test_state_model.py index dfa89a472ec8..e5371c39bdf8 100644 --- a/src/backend/tests/unit/graph/graph/state/test_state_model.py +++ b/src/backend/tests/unit/graph/graph/state/test_state_model.py @@ -22,27 +22,27 @@ class TestCreateStateModel: # Successfully create a model with valid method return type annotations def test_create_model_with_valid_return_type_annotations(self, chat_input_component): - StateModel = create_state_model(method_one=chat_input_component.message_response) + state_model = create_state_model(method_one=chat_input_component.message_response) - state_instance = StateModel() + state_instance = state_model() assert state_instance.method_one is UNDEFINED chat_input_component.set_output_value("message", "test") assert state_instance.method_one == "test" def test_create_model_and_assign_values_fails(self, chat_input_component): - StateModel = create_state_model(method_one=chat_input_component.message_response) + state_model = create_state_model(method_one=chat_input_component.message_response) - state_instance = StateModel() + state_instance = state_model() state_instance.method_one = "test" assert state_instance.method_one == "test" def test_create_with_multiple_components(self, chat_input_component, chat_output_component): - NewStateModel = create_state_model( + new_state_model = create_state_model( model_name="NewStateModel", first_method=chat_input_component.message_response, second_method=chat_output_component.message_response, ) - state_instance = NewStateModel() + state_instance = new_state_model() assert state_instance.first_method is UNDEFINED assert state_instance.second_method is UNDEFINED state_instance.first_method = "test" @@ -51,9 +51,9 @@ def test_create_with_multiple_components(self, chat_input_component, chat_output assert state_instance.second_method == 123 def test_create_with_pydantic_field(self, chat_input_component): - StateModel = create_state_model(method_one=chat_input_component.message_response, my_attribute=Field(None)) + state_model = create_state_model(method_one=chat_input_component.message_response, my_attribute=Field(None)) - state_instance = StateModel() + state_instance = state_model() state_instance.method_one = "test" state_instance.my_attribute = "test" assert state_instance.method_one == "test" @@ -64,8 +64,8 @@ def test_create_with_pydantic_field(self, chat_input_component): # Creates a model with fields based on provided keyword arguments def test_create_model_with_fields_from_kwargs(self): - StateModel = create_state_model(field_one=(str, "default"), field_two=(int, 123)) - state_instance = StateModel() + state_model = create_state_model(field_one=(str, "default"), field_two=(int, 123)) + state_instance = state_model() assert state_instance.field_one == "default" assert state_instance.field_two == 123 @@ -81,16 +81,16 @@ def test_raise_valueerror_for_unsupported_value_types(self): # Handles empty keyword arguments gracefully def test_handle_empty_kwargs_gracefully(self): - StateModel = create_state_model() - state_instance = StateModel() + state_model = create_state_model() + state_instance = state_model() assert state_instance is not None # Ensures model name defaults to "State" if not provided def test_default_model_name_to_state(self): - StateModel = create_state_model() - assert StateModel.__name__ == "State" - OtherNameModel = create_state_model(model_name="OtherName") - assert OtherNameModel.__name__ == "OtherName" + state_model = create_state_model() + assert state_model.__name__ == "State" + other_name_model = create_state_model(model_name="OtherName") + assert other_name_model.__name__ == "OtherName" # Validates that callable values are properly type-annotated @@ -110,8 +110,7 @@ def test_graph_functional_start_state_update(self): chat_input = ChatInput(_id="chat_input") chat_output = ChatOutput(input_value="test", _id="chat_output") chat_output.set(sender_name=chat_input.message_response) - ChatStateModel = create_state_model(model_name="ChatState", message=chat_output.message_response) - chat_state_model = ChatStateModel() + chat_state_model = create_state_model(model_name="ChatState", message=chat_output.message_response)() assert chat_state_model.__class__.__name__ == "ChatState" assert chat_state_model.message is UNDEFINED @@ -121,9 +120,7 @@ def test_graph_functional_start_state_update(self): # and check that the graph is running # correctly ids = ["chat_input", "chat_output"] - results = [] - for result in graph.start(): - results.append(result) + results = list(graph.start()) assert len(results) == 3 assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex")) diff --git a/src/backend/tests/unit/graph/graph/test_base.py b/src/backend/tests/unit/graph/graph/test_base.py index df13063ea551..3d028bc84f7f 100644 --- a/src/backend/tests/unit/graph/graph/test_base.py +++ b/src/backend/tests/unit/graph/graph/test_base.py @@ -9,7 +9,6 @@ from langflow.components.tools.YfinanceTool import YfinanceToolComponent from langflow.graph.graph.base import Graph from langflow.graph.graph.constants import Finish -from pytest import LogCaptureFixture @pytest.mark.asyncio @@ -19,12 +18,12 @@ async def test_graph_not_prepared(): graph = Graph() graph.add_component(chat_input) graph.add_component(chat_output) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Graph not prepared"): await graph.astep() @pytest.mark.asyncio -async def test_graph(caplog: LogCaptureFixture): +async def test_graph(caplog: pytest.LogCaptureFixture): chat_input = ChatInput() chat_output = ChatOutput() graph = Graph() @@ -83,9 +82,7 @@ async def test_graph_functional_async_start(): # and check that the graph is running # correctly ids = ["chat_input", "chat_output"] - results = [] - async for result in graph.async_start(): - results.append(result) + results = [result async for result in graph.async_start()] assert len(results) == 3 assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex")) @@ -102,9 +99,7 @@ def test_graph_functional_start(): # and check that the graph is running # correctly ids = ["chat_input", "chat_output"] - results = [] - for result in graph.start(): - results.append(result) + results = list(graph.start()) assert len(results) == 3 assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex")) @@ -123,9 +118,7 @@ def test_graph_functional_start_end(): # and check that the graph is running # correctly ids = ["chat_input", "text_output"] - results = [] - for result in graph.start(): - results.append(result) + results = list(graph.start()) assert len(results) == len(ids) + 1 assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex")) diff --git a/src/backend/tests/unit/graph/graph/test_cycles.py b/src/backend/tests/unit/graph/graph/test_cycles.py index 2530d07bac50..3a15d292817a 100644 --- a/src/backend/tests/unit/graph/graph/test_cycles.py +++ b/src/backend/tests/unit/graph/graph/test_cycles.py @@ -103,11 +103,9 @@ def test_cycle_in_graph_max_iterations(): # Run queue should contain chat_input and not router assert "chat_input" in graph._run_queue assert "router" not in graph._run_queue - results = [] with pytest.raises(ValueError, match="Max iterations reached"): - for result in graph.start(max_iterations=2, config={"output": {"cache": False}}): - results.append(result) + list(graph.start(max_iterations=2, config={"output": {"cache": False}})) def test_that_outputs_cache_is_set_to_false_in_cycle(): @@ -149,7 +147,10 @@ def test_updated_graph_with_prompts(): # First prompt: Guessing game with hints prompt_component_1 = PromptComponent(_id="prompt_component_1").set( - template="Try to guess a word. I will give you hints if you get it wrong.\nHint: {hint}\nLast try: {last_try}\nAnswer:", + template="Try to guess a word. I will give you hints if you get it wrong.\n" + "Hint: {hint}\n" + "Last try: {last_try}\n" + "Answer:", ) # First OpenAI LLM component (Processes the guessing prompt) @@ -168,7 +169,10 @@ def test_updated_graph_with_prompts(): # Second prompt: After the last try, provide a new hint prompt_component_2 = PromptComponent(_id="prompt_component_2") prompt_component_2.set( - template="Given the following word and the following last try. Give the guesser a new hint.\nLast try: {last_try}\nWord: {word}\nHint:", + template="Given the following word and the following last try. Give the guesser a new hint.\n" + "Last try: {last_try}\n" + "Word: {word}\n" + "Hint:", word=chat_input.message_response, last_try=router.false_response, ) diff --git a/src/backend/tests/unit/graph/graph/test_graph_state_model.py b/src/backend/tests/unit/graph/graph/test_graph_state_model.py index 6fb714fa5652..ece9fd113c9f 100644 --- a/src/backend/tests/unit/graph/graph/test_graph_state_model.py +++ b/src/backend/tests/unit/graph/graph/test_graph_state_model.py @@ -38,9 +38,9 @@ def test_graph_state_model(): graph = Graph(chat_input, chat_output) - GraphStateModel = create_state_model_from_graph(graph) - assert GraphStateModel.__name__ == "GraphStateModel" - assert list(GraphStateModel.model_computed_fields.keys()) == [ + graph_state_model = create_state_model_from_graph(graph) + assert graph_state_model.__name__ == "GraphStateModel" + assert list(graph_state_model.model_computed_fields.keys()) == [ "chat_input", "chat_output", "openai", @@ -60,12 +60,9 @@ def test_graph_functional_start_graph_state_update(): # Now iterate through the graph # and check that the graph is running # correctly - GraphStateModel = create_state_model_from_graph(graph) - graph_state_model = GraphStateModel() + graph_state_model = create_state_model_from_graph(graph)() ids = ["chat_input", "chat_output"] - results = [] - for result in graph.start(): - results.append(result) + results = list(graph.start()) assert len(results) == 3 assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex")) @@ -87,12 +84,9 @@ def test_graph_state_model_serialization(): # Now iterate through the graph # and check that the graph is running # correctly - GraphStateModel = create_state_model_from_graph(graph) - graph_state_model = GraphStateModel() + graph_state_model = create_state_model_from_graph(graph)() ids = ["chat_input", "chat_output"] - results = [] - for result in graph.start(): - results.append(result) + results = list(graph.start()) assert len(results) == 3 assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex")) @@ -116,8 +110,7 @@ def test_graph_state_model_json_schema(): graph = Graph(chat_input, chat_output) graph.prepare() - GraphStateModel = create_state_model_from_graph(graph) - graph_state_model: BaseModel = GraphStateModel() + graph_state_model: BaseModel = create_state_model_from_graph(graph)() json_schema = graph_state_model.model_json_schema(mode="serialization") # Test main schema structure diff --git a/src/backend/tests/unit/graph/graph/test_runnable_vertices_manager.py b/src/backend/tests/unit/graph/graph/test_runnable_vertices_manager.py index 4ff3151d536f..188d19b915ca 100644 --- a/src/backend/tests/unit/graph/graph/test_runnable_vertices_manager.py +++ b/src/backend/tests/unit/graph/graph/test_runnable_vertices_manager.py @@ -66,7 +66,7 @@ def test_pickle(data): manager = RunnableVerticesManager.from_dict(data) binary = pickle.dumps(manager) - result = pickle.loads(binary) + result = pickle.loads(binary) # noqa: S301 assert result.run_map == manager.run_map assert result.run_predecessors == manager.run_predecessors diff --git a/src/backend/tests/unit/graph/graph/test_utils.py b/src/backend/tests/unit/graph/graph/test_utils.py index 6bf09542a03d..982f73118af2 100644 --- a/src/backend/tests/unit/graph/graph/test_utils.py +++ b/src/backend/tests/unit/graph/graph/test_utils.py @@ -119,7 +119,7 @@ def test_sort_up_to_vertex_a(graph): def test_sort_up_to_vertex_invalid_vertex(graph): vertex_id = "7" - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Parent node map is required to find the root of a group node"): utils.sort_up_to_vertex(graph, vertex_id) @@ -432,7 +432,7 @@ def test_handle_duplicate_edges_fixed_fixed(self): assert sorted(result) == sorted(expected_output) @pytest.mark.parametrize("_", range(5)) - def test_handle_two_inputs_in_cycle(self, _): + def test_handle_two_inputs_in_cycle(self, _): # noqa: PT019 edges = [ ("chat_input", "router"), ("chat_input", "concatenate"), diff --git a/src/backend/tests/unit/graph/test_graph.py b/src/backend/tests/unit/graph/test_graph.py index 3bb328daf431..af9f009cc5c1 100644 --- a/src/backend/tests/unit/graph/test_graph.py +++ b/src/backend/tests/unit/graph/test_graph.py @@ -79,8 +79,8 @@ def test_invalid_node_types(): ], "edges": [], } - with pytest.raises(Exception): - g = Graph() + g = Graph() + with pytest.raises(KeyError): g.add_nodes_and_edges(graph_data["nodes"], graph_data["edges"]) diff --git a/src/backend/tests/unit/helpers/test_base_model_from_schema.py b/src/backend/tests/unit/helpers/test_base_model_from_schema.py index 0a371a37b115..d07a4908e0a3 100644 --- a/src/backend/tests/unit/helpers/test_base_model_from_schema.py +++ b/src/backend/tests/unit/helpers/test_base_model_from_schema.py @@ -79,13 +79,13 @@ def test_manages_unknown_field_types(self): {"name": "field1", "type": "str", "default": "default_value1"}, {"name": "field2", "type": "unknown_type", "default": "default_value2"}, ] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid type: unknown_type"): build_model_from_schema(schema) # Confirms that the function raises a specific exception for invalid input def test_raises_error_for_invalid_input_different_exception_with_specific_exception(self): - with pytest.raises(ValueError): - schema = [{"name": "field1", "type": "invalid_type", "default": "default_value"}] + schema = [{"name": "field1", "type": "invalid_type", "default": "default_value"}] + with pytest.raises(ValueError, match="Invalid type: invalid_type"): build_model_from_schema(schema) # Processes schemas with missing optional keys like description or multiple diff --git a/src/backend/tests/unit/initial_setup/starter_projects/test_memory_chatbot.py b/src/backend/tests/unit/initial_setup/starter_projects/test_memory_chatbot.py index 59016abd6f85..9077e60ca88a 100644 --- a/src/backend/tests/unit/initial_setup/starter_projects/test_memory_chatbot.py +++ b/src/backend/tests/unit/initial_setup/starter_projects/test_memory_chatbot.py @@ -43,6 +43,7 @@ def memory_chatbot_graph(): return graph +@pytest.mark.usefixtures("client") def test_memory_chatbot(memory_chatbot_graph): # Now we run step by step expected_order = deque(["chat_input", "chat_memory", "prompt", "openai", "chat_output"]) diff --git a/src/backend/tests/unit/initial_setup/starter_projects/test_vector_store_rag.py b/src/backend/tests/unit/initial_setup/starter_projects/test_vector_store_rag.py index 0d383b431c51..c898e36b7873 100644 --- a/src/backend/tests/unit/initial_setup/starter_projects/test_vector_store_rag.py +++ b/src/backend/tests/unit/initial_setup/starter_projects/test_vector_store_rag.py @@ -34,7 +34,7 @@ def ingestion_graph(): embedding=openai_embeddings.build_embeddings, ingest_data=text_splitter.split_text, api_endpoint="https://astra.example.com", - token="token", + token="token", # noqa: S106 ) vector_store.set_on_output(name="vector_store", value="mock_vector_store", cache=True) vector_store.set_on_output(name="base_retriever", value="mock_retriever", cache=True) @@ -53,7 +53,7 @@ def rag_graph(): rag_vector_store.set( search_input=chat_input.message_response, api_endpoint="https://astra.example.com", - token="token", + token="token", # noqa: S106 embedding=openai_embeddings.build_embeddings, ) # Mock search_documents @@ -110,9 +110,7 @@ def test_vector_store_rag(ingestion_graph, rag_graph): "openai-embeddings-124", ] for ids, graph, len_results in [(ingestion_ids, ingestion_graph, 5), (rag_ids, rag_graph, 8)]: - results = [] - for result in graph.start(): - results.append(result) + results = list(graph.start()) assert len(results) == len_results vids = [result.vertex.id for result in results if hasattr(result, "vertex")] @@ -217,12 +215,14 @@ def test_vector_store_rag_add(ingestion_graph: Graph, rag_graph: Graph): rag_graph_copy = copy.deepcopy(rag_graph) ingestion_graph_copy += rag_graph_copy - assert ( - len(ingestion_graph_copy.vertices) == len(ingestion_graph.vertices) + len(rag_graph.vertices) - ), f"Vertices mismatch: {len(ingestion_graph_copy.vertices)} != {len(ingestion_graph.vertices)} + {len(rag_graph.vertices)}" - assert len(ingestion_graph_copy.edges) == len(ingestion_graph.edges) + len( - rag_graph.edges - ), f"Edges mismatch: {len(ingestion_graph_copy.edges)} != {len(ingestion_graph.edges)} + {len(rag_graph.edges)}" + assert len(ingestion_graph_copy.vertices) == len(ingestion_graph.vertices) + len(rag_graph.vertices), ( + f"Vertices mismatch: {len(ingestion_graph_copy.vertices)} " + f"!= {len(ingestion_graph.vertices)} + {len(rag_graph.vertices)}" + ) + assert len(ingestion_graph_copy.edges) == len(ingestion_graph.edges) + len(rag_graph.edges), ( + f"Edges mismatch: {len(ingestion_graph_copy.edges)} " + f"!= {len(ingestion_graph.edges)} + {len(rag_graph.edges)}" + ) combined_graph_dump = ingestion_graph_copy.dump( name="Combined Graph", description="Graph for data ingestion and RAG", endpoint_name="combined" diff --git a/src/backend/tests/unit/inputs/test_inputs.py b/src/backend/tests/unit/inputs/test_inputs.py index f45be41f6600..8d4386bcb738 100644 --- a/src/backend/tests/unit/inputs/test_inputs.py +++ b/src/backend/tests/unit/inputs/test_inputs.py @@ -70,7 +70,7 @@ def test_instantiate_input_valid(): def test_instantiate_input_invalid(): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid input type: InvalidInput"): instantiate_input("InvalidInput", {"name": "invalid_input", "value": "This is a string"}) @@ -224,5 +224,5 @@ def test_instantiate_input_comprehensive(): input_instance = instantiate_input(input_type, data) assert isinstance(input_instance, InputTypesMap[input_type]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Invalid input type: InvalidInput"): instantiate_input("InvalidInput", {"name": "invalid_input", "value": "Invalid"}) diff --git a/src/backend/tests/unit/io/test_io_schema.py b/src/backend/tests/unit/io/test_io_schema.py index cae508d3ff9c..456840875063 100644 --- a/src/backend/tests/unit/io/test_io_schema.py +++ b/src/backend/tests/unit/io/test_io_schema.py @@ -180,7 +180,7 @@ def test_is_list_handling(self): input_instance = StrInput(name="test_field", is_list=True) schema = create_input_schema([input_instance]) field_info = schema.model_fields["test_field"] - assert field_info.annotation == list[str] # type: ignore + assert field_info.annotation == list[str] # Converting FieldTypes to corresponding Python types def test_field_types_conversion(self): diff --git a/src/backend/tests/unit/io/test_table_schema.py b/src/backend/tests/unit/io/test_table_schema.py index e0e5d9f799b9..423731943098 100644 --- a/src/backend/tests/unit/io/test_table_schema.py +++ b/src/backend/tests/unit/io/test_table_schema.py @@ -33,7 +33,7 @@ def test_formatter_explicitly_set_to_enum(self): # Invalid formatter raises ValueError def test_invalid_formatter_raises_value_error(self): - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="'invalid' is not a valid FormatterType"): Column(display_name="Invalid Column", name="invalid_column", formatter="invalid") # Formatter is None when not provided diff --git a/src/backend/tests/unit/services/variable/test_service.py b/src/backend/tests/unit/services/variable/test_service.py index 7a6e7c563a26..081c7147f29b 100644 --- a/src/backend/tests/unit/services/variable/test_service.py +++ b/src/backend/tests/unit/services/variable/test_service.py @@ -72,19 +72,16 @@ def test_get_variable(service, session): assert result == value -def test_get_variable__ValueError(service, session): +def test_get_variable__valueerror(service, session): user_id = uuid4() name = "name" field = "" - with pytest.raises(ValueError) as exc: + with pytest.raises(ValueError, match=f"{name} variable not found."): service.get_variable(user_id, name, field, session) - assert name in str(exc.value) - assert "variable not found" in str(exc.value) - -def test_get_variable__TypeError(service, session): +def test_get_variable__typeerror(service, session): user_id = uuid4() name = "name" value = "value" @@ -142,17 +139,14 @@ def test_update_variable(service, session): assert isinstance(result.updated_at, datetime) -def test_update_variable__ValueError(service, session): +def test_update_variable__valueerror(service, session): user_id = uuid4() name = "name" value = "value" - with pytest.raises(ValueError) as exc: + with pytest.raises(ValueError, match=f"{name} variable not found."): service.update_variable(user_id, name, value, session=session) - assert name in str(exc.value) - assert "variable not found" in str(exc.value) - def test_update_variable_fields(service, session): user_id = uuid4() @@ -192,26 +186,21 @@ def test_delete_variable(service, session): service.create_variable(user_id, name, value, session=session) recovered = service.get_variable(user_id, name, field, session=session) service.delete_variable(user_id, name, session=session) - with pytest.raises(ValueError) as exc: + with pytest.raises(ValueError, match=f"{name} variable not found."): service.get_variable(user_id, name, field, session) assert recovered == value - assert name in str(exc.value) - assert "variable not found" in str(exc.value) -def test_delete_variable__ValueError(service, session): +def test_delete_variable__valueerror(service, session): user_id = uuid4() name = "name" - with pytest.raises(ValueError) as exc: + with pytest.raises(ValueError, match=f"{name} variable not found."): service.delete_variable(user_id, name, session=session) - assert name in str(exc.value) - assert "variable not found" in str(exc.value) - -def test_delete_varaible_by_id(service, session): +def test_delete_variable_by_id(service, session): user_id = uuid4() name = "name" value = "value" @@ -220,24 +209,19 @@ def test_delete_varaible_by_id(service, session): saved = service.create_variable(user_id, name, value, session=session) recovered = service.get_variable(user_id, name, field, session=session) service.delete_variable_by_id(user_id, saved.id, session=session) - with pytest.raises(ValueError) as exc: + with pytest.raises(ValueError, match=f"{name} variable not found."): service.get_variable(user_id, name, field, session) assert recovered == value - assert name in str(exc.value) - assert "variable not found" in str(exc.value) -def test_delete_variable_by_id__ValueError(service, session): +def test_delete_variable_by_id__valueerror(service, session): user_id = uuid4() variable_id = uuid4() - with pytest.raises(ValueError) as exc: + with pytest.raises(ValueError, match=f"{variable_id} variable not found."): service.delete_variable_by_id(user_id, variable_id, session=session) - assert str(variable_id) in str(exc.value) - assert "variable not found" in str(exc.value) - def test_create_variable(service, session): user_id = uuid4() diff --git a/src/backend/tests/unit/test_api_key.py b/src/backend/tests/unit/test_api_key.py index 50970e4e62d9..d3358f873486 100644 --- a/src/backend/tests/unit/test_api_key.py +++ b/src/backend/tests/unit/test_api_key.py @@ -4,7 +4,11 @@ @pytest.fixture -async def api_key(client, logged_in_headers, active_user): +async def api_key( + client, + logged_in_headers, + active_user, # noqa: ARG001 +): api_key = ApiKeyCreate(name="test-api-key") response = await client.post("api/v1/api_key/", data=api_key.model_dump_json(), headers=logged_in_headers) diff --git a/src/backend/tests/unit/test_custom_component.py b/src/backend/tests/unit/test_custom_component.py index 1c864a5fde1c..3985c5e0eefc 100644 --- a/src/backend/tests/unit/test_custom_component.py +++ b/src/backend/tests/unit/test_custom_component.py @@ -51,9 +51,7 @@ def test_code_parser_get_tree(): def test_code_parser_syntax_error(): - """Test the __get_tree method raises the - CodeSyntaxError when given incorrect syntax. - """ + """Test the __get_tree method raises the CodeSyntaxError when given incorrect syntax.""" code_syntax_error = "zzz import os" parser = CodeParser(code_syntax_error) @@ -76,9 +74,7 @@ def test_component_get_code_tree(): def test_component_code_null_error(): - """Test the get_function method raises the - ComponentCodeNullError when the code is empty. - """ + """Test the get_function method raises the ComponentCodeNullError when the code is empty.""" component = BaseComponent(_code="", _function_entrypoint_name="") with pytest.raises(ComponentCodeNullError): component.get_function() @@ -108,9 +104,7 @@ def test_custom_component_get_function(): def test_code_parser_parse_imports_import(): - """Test the parse_imports method of the CodeParser - class with an import statement. - """ + """Test the parse_imports method of the CodeParser class with an import statement.""" parser = CodeParser(code_default) tree = parser.get_tree() for node in ast.walk(tree): @@ -120,9 +114,7 @@ class with an import statement. def test_code_parser_parse_imports_importfrom(): - """Test the parse_imports method of the CodeParser - class with an import from statement. - """ + """Test the parse_imports method of the CodeParser class with an import from statement.""" parser = CodeParser("from os import path") tree = parser.get_tree() for node in ast.walk(tree): @@ -157,9 +149,9 @@ def test_code_parser_parse_classes_raises(): """Test the parse_classes method of the CodeParser class.""" parser = CodeParser("class Test: pass") tree = parser.get_tree() - with pytest.raises(TypeError): - for node in ast.walk(tree): - if isinstance(node, ast.ClassDef): + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + with pytest.raises(TypeError): parser.parse_classes(node) @@ -175,18 +167,14 @@ def test_code_parser_parse_global_vars(): def test_component_get_function_valid(): - """Test the get_function method of the Component - class with valid code and function_entrypoint_name. - """ + """Test the get_function method of the Component class with valid code and function_entrypoint_name.""" component = BaseComponent(_code="def build(): pass", _function_entrypoint_name="build") my_function = component.get_function() assert callable(my_function) def test_custom_component_get_function_entrypoint_args(): - """Test the get_function_entrypoint_args - property of the CustomComponent class. - """ + """Test the get_function_entrypoint_args property of the CustomComponent class.""" custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build") args = custom_component.get_function_entrypoint_args assert len(args) == 3 @@ -196,9 +184,7 @@ def test_custom_component_get_function_entrypoint_args(): def test_custom_component_get_function_entrypoint_return_type(): - """Test the get_function_entrypoint_return_type - property of the CustomComponent class. - """ + """Test the get_function_entrypoint_return_type property of the CustomComponent class.""" custom_component = CustomComponent(_code=code_default, _function_entrypoint_name="build") return_type = custom_component.get_function_entrypoint_return_type assert return_type == [Document] @@ -212,9 +198,7 @@ def test_custom_component_get_main_class_name(): def test_custom_component_get_function_valid(): - """Test the get_function property of the CustomComponent - class with valid code and function_entrypoint_name. - """ + """Test the get_function property of the CustomComponent class with valid code and function_entrypoint_name.""" custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build") my_function = custom_component.get_function assert callable(my_function) @@ -239,9 +223,7 @@ def test_code_parser_parse_arg_with_annotation(): def test_code_parser_parse_callable_details_no_args(): - """Test the parse_callable_details method of the - CodeParser class with a function with no arguments. - """ + """Test the parse_callable_details method of the CodeParser class with a function with no arguments.""" parser = CodeParser("") node = ast.FunctionDef( name="test", @@ -280,9 +262,7 @@ def test_code_parser_parse_ann_assign(): def test_code_parser_parse_function_def_not_init(): - """Test the parse_function_def method of the - CodeParser class with a function that is not __init__. - """ + """Test the parse_function_def method of the CodeParser class with a function that is not __init__.""" parser = CodeParser("") stmt = ast.FunctionDef( name="test", @@ -297,9 +277,7 @@ def test_code_parser_parse_function_def_not_init(): def test_code_parser_parse_function_def_init(): - """Test the parse_function_def method of the - CodeParser class with an __init__ function. - """ + """Test the parse_function_def method of the CodeParser class with an __init__ function.""" parser = CodeParser("") stmt = ast.FunctionDef( name="__init__", @@ -314,36 +292,28 @@ def test_code_parser_parse_function_def_init(): def test_component_get_code_tree_syntax_error(): - """Test the get_code_tree method of the Component class - raises the CodeSyntaxError when given incorrect syntax. - """ + """Test the get_code_tree method of the Component class raises the CodeSyntaxError when given incorrect syntax.""" component = BaseComponent(_code="import os as", _function_entrypoint_name="build") with pytest.raises(CodeSyntaxError): component.get_code_tree(component._code) def test_custom_component_class_template_validation_no_code(): - """Test the _class_template_validation method of the CustomComponent class - raises the HTTPException when the code is None. - """ + """Test CustomComponent._class_template_validation raises the HTTPException when the code is None.""" custom_component = CustomComponent(_code=None, _function_entrypoint_name="build") with pytest.raises(TypeError): custom_component.get_function() def test_custom_component_get_code_tree_syntax_error(): - """Test the get_code_tree method of the CustomComponent class - raises the CodeSyntaxError when given incorrect syntax. - """ + """Test CustomComponent.get_code_tree raises the CodeSyntaxError when given incorrect syntax.""" custom_component = CustomComponent(_code="import os as", _function_entrypoint_name="build") with pytest.raises(CodeSyntaxError): custom_component.get_code_tree(custom_component._code) def test_custom_component_get_function_entrypoint_args_no_args(): - """Test the get_function_entrypoint_args property of - the CustomComponent class with a build method with no arguments. - """ + """Test CustomComponent.get_function_entrypoint_args with a build method with no arguments.""" my_code = """ from langflow.custom import CustomComponent class MyMainClass(CustomComponent): @@ -356,9 +326,7 @@ def build(): def test_custom_component_get_function_entrypoint_return_type_no_return_type(): - """Test the get_function_entrypoint_return_type property of the - CustomComponent class with a build method with no return type. - """ + """Test CustomComponent.get_function_entrypoint_return_type with a build method with no return type.""" my_code = """ from langflow.custom import CustomComponent class MyClass(CustomComponent): @@ -371,9 +339,7 @@ def build(): def test_custom_component_get_main_class_name_no_main_class(): - """Test the get_main_class_name property of the - CustomComponent class when there is no main class. - """ + """Test the get_main_class_name property of the CustomComponent class when there is no main class.""" my_code = """ def build(): pass""" @@ -384,9 +350,7 @@ def build(): def test_custom_component_build_not_implemented(): - """Test the build method of the CustomComponent - class raises the NotImplementedError. - """ + """Test the build method of the CustomComponent class raises the NotImplementedError.""" custom_component = CustomComponent(_code="def build(): pass", _function_entrypoint_name="build") with pytest.raises(NotImplementedError): custom_component.build() diff --git a/src/backend/tests/unit/test_custom_component_with_client.py b/src/backend/tests/unit/test_custom_component_with_client.py index 16fc4116bda0..3c029be45ed3 100644 --- a/src/backend/tests/unit/test_custom_component_with_client.py +++ b/src/backend/tests/unit/test_custom_component_with_client.py @@ -15,7 +15,10 @@ def code_component_with_multiple_outputs(): @pytest.fixture -def component(client, active_user): +def component( + client, # noqa: ARG001 + active_user, +): return CustomComponent( user_id=active_user.id, field_config={ diff --git a/src/backend/tests/unit/test_database.py b/src/backend/tests/unit/test_database.py index 2a6c57146c34..d6cdf82b6952 100644 --- a/src/backend/tests/unit/test_database.py +++ b/src/backend/tests/unit/test_database.py @@ -1,5 +1,5 @@ import json -from collections import namedtuple +from typing import NamedTuple from uuid import UUID, uuid4 import orjson @@ -13,7 +13,6 @@ from langflow.services.database.models.folder.model import FolderCreate from langflow.services.database.utils import session_getter from langflow.services.deps import get_db_service -from sqlmodel import Session @pytest.fixture(scope="module") @@ -179,11 +178,11 @@ async def test_delete_flows_with_transaction_and_build(client: TestClient, logge assert response.status_code == 201 flow_ids.append(response.json()["id"]) - # Create a transaction for each flow + class VertexTuple(NamedTuple): + id: str + # Create a transaction for each flow for flow_id in flow_ids: - VertexTuple = namedtuple("VertexTuple", ["id"]) - await log_transaction( str(flow_id), source=VertexTuple(id="vid"), target=VertexTuple(id="tid"), status="success" ) @@ -249,10 +248,11 @@ async def test_delete_folder_with_flows_with_transaction_and_build(client: TestC assert response.status_code == 201 flow_ids.append(response.json()["id"]) + class VertexTuple(NamedTuple): + id: str + # Create a transaction for each flow for flow_id in flow_ids: - VertexTuple = namedtuple("VertexTuple", ["id"]) - await log_transaction( str(flow_id), source=VertexTuple(id="vid"), target=VertexTuple(id="tid"), status="success" ) @@ -400,9 +400,9 @@ async def test_upload_file(client: TestClient, json_flow: str, logged_in_headers assert response_data[1]["data"] == data +@pytest.mark.usefixtures("session") async def test_download_file( client: TestClient, - session: Session, json_flow, active_user, logged_in_headers, @@ -419,14 +419,14 @@ async def test_download_file( ] ) db_manager = get_db_service() - with session_getter(db_manager) as session: + with session_getter(db_manager) as _session: saved_flows = [] for flow in flow_list.flows: flow.user_id = active_user.id db_flow = Flow.model_validate(flow, from_attributes=True) - session.add(db_flow) + _session.add(db_flow) saved_flows.append(db_flow) - session.commit() + _session.commit() # Make request to endpoint inside the session context flow_ids = [str(db_flow.id) for db_flow in saved_flows] # Convert UUIDs to strings flow_ids_json = json.dumps(flow_ids) diff --git a/src/backend/tests/unit/test_endpoints.py b/src/backend/tests/unit/test_endpoints.py index 5236a823c4c3..bd207212e076 100644 --- a/src/backend/tests/unit/test_endpoints.py +++ b/src/backend/tests/unit/test_endpoints.py @@ -1,4 +1,4 @@ -import time +import asyncio from uuid import UUID, uuid4 import pytest @@ -27,7 +27,7 @@ async def poll_task_status(client, headers, href, max_attempts=20, sleep_time=1) ) if task_status_response.status_code == 200 and task_status_response.json()["status"] == "SUCCESS": return task_status_response.json() - time.sleep(sleep_time) + await asyncio.sleep(sleep_time) return None # Return None if task did not complete in time diff --git a/src/backend/tests/unit/test_files.py b/src/backend/tests/unit/test_files.py index e49af8eec38d..84eb7451165d 100644 --- a/src/backend/tests/unit/test_files.py +++ b/src/backend/tests/unit/test_files.py @@ -26,7 +26,13 @@ def mock_storage_service(): @pytest.fixture(name="files_client") -async def files_client_fixture(session: Session, monkeypatch, request, load_flows_dir, mock_storage_service): +async def files_client_fixture( + session: Session, # noqa: ARG001 + monkeypatch, + request, + load_flows_dir, + mock_storage_service, +): # Set the database url to a test database if "noclient" in request.keywords: yield @@ -47,9 +53,11 @@ async def files_client_fixture(session: Session, monkeypatch, request, load_flow app = create_app() app.dependency_overrides[get_storage_service] = lambda: mock_storage_service - async with LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager: - async with AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client: - yield client + async with ( + LifespanManager(app, startup_timeout=None, shutdown_timeout=None) as manager, + AsyncClient(transport=ASGITransport(app=manager.app), base_url="http://testserver/") as client, + ): + yield client # app.dependency_overrides.clear() monkeypatch.undo() # clear the temp db diff --git a/src/backend/tests/unit/test_initial_setup.py b/src/backend/tests/unit/test_initial_setup.py index 0a2f64e3c253..5236c1b0ba00 100644 --- a/src/backend/tests/unit/test_initial_setup.py +++ b/src/backend/tests/unit/test_initial_setup.py @@ -59,7 +59,8 @@ async def test_create_or_update_starter_projects(): assert folder is not None num_db_projects = len(folder.flows) - # Check that the number of projects in the database is the same as the number of projects returned by load_starter_projects + # Check that the number of projects in the database is the same as the number of projects returned by + # load_starter_projects assert num_db_projects == num_projects @@ -76,7 +77,8 @@ async def test_create_or_update_starter_projects(): # # Get the number of projects in the database # num_db_projects = session.exec(select(func.count(Flow.id)).where(Flow.folder == STARTER_FOLDER_NAME)).one() -# # Check that the number of projects in the database is the same as the number of projects returned by load_starter_projects +# # Check that the number of projects in the database is the same as the number of projects returned by +# # load_starter_projects # assert num_db_projects == num_projects # # Get all the starter projects @@ -99,7 +101,7 @@ async def test_create_or_update_starter_projects(): # delete_messages(session_id="test") -def find_componeny_by_name(components, name): +def find_component_by_name(components, name): for children in components.values(): if name in children: return children[name] @@ -111,17 +113,17 @@ def set_value(component, input_name, value): component["template"][input_name]["value"] = value -def component_to_node(id, type, component): - return {"id": type + id, "data": {"node": component, "type": type, "id": id}} +def component_to_node(node_id, node_type, component): + return {"id": node_type + node_id, "data": {"node": component, "type": node_type, "id": node_id}} -def add_edge(input, output, from_output, to_input): +def add_edge(source, target, from_output, to_input): return { - "source": input, - "target": output, + "source": source, + "target": target, "data": { - "sourceHandle": {"dataType": "ChatInput", "id": input, "name": from_output, "output_types": ["Message"]}, - "targetHandle": {"fieldName": to_input, "id": output, "inputTypes": ["Message"], "type": "str"}, + "sourceHandle": {"dataType": "ChatInput", "id": source, "name": from_output, "output_types": ["Message"]}, + "targetHandle": {"fieldName": to_input, "id": target, "inputTypes": ["Message"], "type": "str"}, }, } @@ -131,8 +133,8 @@ async def test_refresh_starter_projects(): data_path = str(Path(__file__).parent.parent.parent.absolute() / "base" / "langflow" / "components") components = build_custom_component_list_from_path(data_path) - chat_input = find_componeny_by_name(components, "ChatInput") - chat_output = find_componeny_by_name(components, "ChatOutput") + chat_input = find_component_by_name(components, "ChatInput") + chat_output = find_component_by_name(components, "ChatOutput") chat_output["template"]["code"]["value"] = "changed !" del chat_output["template"]["should_store_message"] graph_data = { diff --git a/src/backend/tests/unit/test_kubernetes_secrets.py b/src/backend/tests/unit/test_kubernetes_secrets.py index 7fcc52ed9bb5..9da44cdef82a 100644 --- a/src/backend/tests/unit/test_kubernetes_secrets.py +++ b/src/backend/tests/unit/test_kubernetes_secrets.py @@ -8,13 +8,13 @@ @pytest.fixture -def mock_kube_config(mocker): +def _mock_kube_config(mocker): mocker.patch("kubernetes.config.load_kube_config") mocker.patch("kubernetes.config.load_incluster_config") @pytest.fixture -def secret_manager(mock_kube_config): +def secret_manager(_mock_kube_config): return KubernetesSecretManager(namespace="test-namespace") diff --git a/src/backend/tests/unit/test_messages.py b/src/backend/tests/unit/test_messages.py index 59b100827aa9..fd8970710ac3 100644 --- a/src/backend/tests/unit/test_messages.py +++ b/src/backend/tests/unit/test_messages.py @@ -19,15 +19,15 @@ def created_message(): @pytest.fixture -def created_messages(session): - with session_scope() as session: +def created_messages(session): # noqa: ARG001 + with session_scope() as _session: messages = [ MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"), MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"), ] messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages] - messagetables = add_messagetables(messagetables, session) + messagetables = add_messagetables(messagetables, _session) return [MessageRead.model_validate(messagetable, from_attributes=True) for messagetable in messagetables] diff --git a/src/backend/tests/unit/test_messages_endpoints.py b/src/backend/tests/unit/test_messages_endpoints.py index c8c41de8ec0e..87f404c7fba5 100644 --- a/src/backend/tests/unit/test_messages_endpoints.py +++ b/src/backend/tests/unit/test_messages_endpoints.py @@ -20,15 +20,15 @@ async def created_message(): @pytest.fixture -def created_messages(session): - with session_scope() as session: +def created_messages(session): # noqa: ARG001 + with session_scope() as _session: messages = [ MessageCreate(text="Test message 1", sender="User", sender_name="User", session_id="session_id2"), MessageCreate(text="Test message 2", sender="User", sender_name="User", session_id="session_id2"), MessageCreate(text="Test message 3", sender="User", sender_name="User", session_id="session_id2"), ] messagetables = [MessageTable.model_validate(message, from_attributes=True) for message in messages] - return add_messagetables(messagetables, session) + return add_messagetables(messagetables, _session) @pytest.mark.api_key_required diff --git a/src/backend/tests/unit/test_process.py b/src/backend/tests/unit/test_process.py index 7858bf6276f7..d909b63e0451 100644 --- a/src/backend/tests/unit/test_process.py +++ b/src/backend/tests/unit/test_process.py @@ -282,12 +282,16 @@ async def test_load_langchain_object_with_cached_session(basic_graph_data): # session_service = get_session_service() # session_id1 = "non-existent-session-id" # session_id = session_service.build_key(session_id1, basic_graph_data) -# graph1, artifacts1 = await session_service.load_session(session_id, data_graph=basic_graph_data, flow_id="flow_id") +# graph1, artifacts1 = await session_service.load_session( +# session_id, data_graph=basic_graph_data, flow_id="flow_id" +# ) # # Clear the cache # await session_service.clear_session(session_id) # # Use the new session_id to get the graph again -# graph2, artifacts2 = await session_service.load_session(session_id, data_graph=basic_graph_data, flow_id="flow_id") - +# graph2, artifacts2 = await session_service.load_session( +# session_id, data_graph=basic_graph_data, flow_id="flow_id" +# ) +# # # Since the cache was cleared, objects should be different # assert id(graph1) != id(graph2) @@ -297,8 +301,12 @@ async def test_load_langchain_object_with_cached_session(basic_graph_data): # # Provide a non-existent session_id # session_service = get_session_service() # session_id1 = None -# graph1, artifacts1 = await session_service.load_session(session_id1, data_graph=basic_graph_data, flow_id="flow_id") +# graph1, artifacts1 = await session_service.load_session( +# session_id1, data_graph=basic_graph_data, flow_id="flow_id" +# ) # # Use the new session_id to get the langchain_object again -# graph2, artifacts2 = await session_service.load_session(session_id1, data_graph=basic_graph_data, flow_id="flow_id") - +# graph2, artifacts2 = await session_service.load_session( +# session_id1, data_graph=basic_graph_data, flow_id="flow_id" +# ) +# # assert graph1 == graph2 diff --git a/src/backend/tests/unit/test_schema.py b/src/backend/tests/unit/test_schema.py index 0938b78aed29..cce9d8bcd3df 100644 --- a/src/backend/tests/unit/test_schema.py +++ b/src/backend/tests/unit/test_schema.py @@ -45,52 +45,52 @@ def test_post_process_type_function(self): assert set(post_process_type(SequenceABC[float])) == {float} # Union types - assert set(post_process_type(Union[int, str])) == {int, str} - assert set(post_process_type(Union[int, SequenceABC[str]])) == {int, str} - assert set(post_process_type(Union[int, SequenceABC[int]])) == {int} + assert set(post_process_type(Union[int, str])) == {int, str} # noqa: UP007 + assert set(post_process_type(Union[int, SequenceABC[str]])) == {int, str} # noqa: UP007 + assert set(post_process_type(Union[int, SequenceABC[int]])) == {int} # noqa: UP007 # Nested Union with lists - assert set(post_process_type(Union[list[int], list[str]])) == {int, str} - assert set(post_process_type(Union[int, list[str], list[float]])) == {int, str, float} + assert set(post_process_type(Union[list[int], list[str]])) == {int, str} # noqa: UP007 + assert set(post_process_type(Union[int, list[str], list[float]])) == {int, str, float} # noqa: UP007 # Custom data types assert set(post_process_type(Data)) == {Data} assert set(post_process_type(list[Data])) == {Data} # Union with custom types - assert set(post_process_type(Union[Data, str])) == {Data, str} - assert set(post_process_type(Union[Data, int, list[str]])) == {Data, int, str} + assert set(post_process_type(Union[Data, str])) == {Data, str} # noqa: UP007 + assert set(post_process_type(Union[Data, int, list[str]])) == {Data, int, str} # noqa: UP007 # Empty lists and edge cases assert set(post_process_type(list)) == {list} - assert set(post_process_type(Union[int, None])) == {int, NoneType} - assert set(post_process_type(Union[None, list[None]])) == {None, NoneType} + assert set(post_process_type(Union[int, None])) == {int, NoneType} # noqa: UP007 + assert set(post_process_type(Union[None, list[None]])) == {None, NoneType} # noqa: UP007 # Handling complex nested structures - assert set(post_process_type(Union[SequenceABC[int | str], list[float]])) == {int, str, float} - assert set(post_process_type(Union[int | list[str] | list[float], str])) == {int, str, float} + assert set(post_process_type(Union[SequenceABC[int | str], list[float]])) == {int, str, float} # noqa: UP007 + assert set(post_process_type(Union[int | list[str] | list[float], str])) == {int, str, float} # noqa: UP007 # Non-generic types should return as is assert set(post_process_type(dict)) == {dict} assert set(post_process_type(tuple)) == {tuple} # Union with custom types - assert set(post_process_type(Union[Data, str])) == {Data, str} + assert set(post_process_type(Union[Data, str])) == {Data, str} # noqa: UP007 assert set(post_process_type(Data | str)) == {Data, str} assert set(post_process_type(Data | int | list[str])) == {Data, int, str} # More complex combinations with Data assert set(post_process_type(Data | list[float])) == {Data, float} - assert set(post_process_type(Data | Union[int, str])) == {Data, int, str} + assert set(post_process_type(Data | Union[int, str])) == {Data, int, str} # noqa: UP007 assert set(post_process_type(Data | list[int] | None)) == {Data, int, type(None)} - assert set(post_process_type(Data | Union[float, None])) == {Data, float, type(None)} + assert set(post_process_type(Data | Union[float, None])) == {Data, float, type(None)} # noqa: UP007 # Multiple Data types combined - assert set(post_process_type(Union[Data, str | float])) == {Data, str, float} - assert set(post_process_type(Union[Data | float | str, int])) == {Data, int, float, str} + assert set(post_process_type(Union[Data, str | float])) == {Data, str, float} # noqa: UP007 + assert set(post_process_type(Union[Data | float | str, int])) == {Data, int, float, str} # noqa: UP007 # Testing with nested unions and lists - assert set(post_process_type(Union[list[Data], list[int | str]])) == {Data, int, str} + assert set(post_process_type(Union[list[Data], list[int | str]])) == {Data, int, str} # noqa: UP007 assert set(post_process_type(Data | list[float | str])) == {Data, float, str} def test_input_to_dict(self): @@ -157,7 +157,7 @@ def test_list_int_type(self): assert post_process_type(list[int]) == [int] def test_union_type(self): - assert set(post_process_type(Union[int, str])) == {int, str} + assert set(post_process_type(Union[int, str])) == {int, str} # noqa: UP007 def test_custom_type(self): class CustomType: @@ -175,4 +175,4 @@ def test_union_custom_type(self): class CustomType: pass - assert set(post_process_type(Union[CustomType, int])) == {CustomType, int} + assert set(post_process_type(Union[CustomType, int])) == {CustomType, int} # noqa: UP007 diff --git a/src/backend/tests/unit/test_setup_superuser.py b/src/backend/tests/unit/test_setup_superuser.py index c2172429b92c..b8fb1cbd1309 100644 --- a/src/backend/tests/unit/test_setup_superuser.py +++ b/src/backend/tests/unit/test_setup_superuser.py @@ -112,11 +112,11 @@ def test_teardown_superuser_default_superuser(mock_get_session, mock_get_setting @patch("langflow.services.deps.get_settings_service") @patch("langflow.services.deps.get_session") def test_teardown_superuser_no_default_superuser(mock_get_session, mock_get_settings_service): - ADMIN_USER_NAME = "admin_user" + admin_user_name = "admin_user" mock_settings_service = MagicMock() mock_settings_service.auth_settings.AUTO_LOGIN = False - mock_settings_service.auth_settings.SUPERUSER = ADMIN_USER_NAME - mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password" + mock_settings_service.auth_settings.SUPERUSER = admin_user_name + mock_settings_service.auth_settings.SUPERUSER_PASSWORD = "password" # noqa: S105 mock_get_settings_service.return_value = mock_settings_service mock_session = MagicMock() diff --git a/src/backend/tests/unit/test_user.py b/src/backend/tests/unit/test_user.py index 3ea8c1655533..41fdaeabe45f 100644 --- a/src/backend/tests/unit/test_user.py +++ b/src/backend/tests/unit/test_user.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, timezone import pytest from httpx import AsyncClient @@ -11,7 +11,7 @@ @pytest.fixture -def super_user(client): +def super_user(client): # noqa: ARG001 settings_manager = get_settings_service() auth_settings = settings_manager.auth_settings with session_getter(get_db_service()) as session: @@ -23,7 +23,10 @@ def super_user(client): @pytest.fixture -async def super_user_headers(client: AsyncClient, super_user): +async def super_user_headers( + client: AsyncClient, + super_user, # noqa: ARG001 +): settings_service = get_settings_service() auth_settings = settings_service.auth_settings login_data = { @@ -38,14 +41,14 @@ async def super_user_headers(client: AsyncClient, super_user): @pytest.fixture -def deactivated_user(client): +def deactivated_user(client): # noqa: ARG001 with session_getter(get_db_service()) as session: user = User( username="deactivateduser", password=get_password_hash("testpassword"), is_active=False, is_superuser=False, - last_login_at=datetime.now(), + last_login_at=datetime.now(tz=timezone.utc), ) session.add(user) session.commit() @@ -55,7 +58,7 @@ def deactivated_user(client): async def test_user_waiting_for_approval(client): username = "waitingforapproval" - password = "testpassword" + password = "testpassword" # noqa: S105 # Debug: Check if the user already exists with session_getter(get_db_service()) as session: @@ -140,7 +143,7 @@ async def test_inactive_user(client: AsyncClient): username="inactiveuser", password=get_password_hash("testpassword"), is_active=False, - last_login_at=datetime(2023, 1, 1, 0, 0, 0), + last_login_at=datetime(2023, 1, 1, 0, 0, 0, tzinfo=timezone.utc), ) session.add(user) session.commit() @@ -205,7 +208,7 @@ async def test_patch_user(client: AsyncClient, active_user, logged_in_headers): async def test_patch_reset_password(client: AsyncClient, active_user, logged_in_headers): user_id = active_user.id update_data = UserUpdate( - password="newpassword", + password="newpassword", # noqa: S106 ) response = await client.patch( diff --git a/src/backend/tests/unit/test_webhook.py b/src/backend/tests/unit/test_webhook.py index 89a0e427e7bc..2c38c8fa7aa8 100644 --- a/src/backend/tests/unit/test_webhook.py +++ b/src/backend/tests/unit/test_webhook.py @@ -5,7 +5,7 @@ @pytest.fixture(autouse=True) -def check_openai_api_key_in_environment_variables(): +def _check_openai_api_key_in_environment_variables(): pass