ref: Add ruff rules for arguments (ARG) (langflow-ai#4123)
Add ruff rules for arguments (ARG)
cbornet authored and smatiolids committed Oct 15, 2024
1 parent df784a0 commit eb34484
Showing 49 changed files with 138 additions and 144 deletions.
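
For context: ruff's ARG rules (from the flake8-unused-arguments plugin) flag arguments that are declared but never read (ARG001 for plain functions, ARG002 for methods) and are switched on by adding "ARG" to the select list in ruff's configuration; that change presumably lives in this commit's pyproject.toml, which is not among the files shown. The hypothetical sketch below, not taken from this commit, shows the three remedies the diff applies: delete the argument, rename it with a leading underscore, or suppress the check where a framework fixes the signature.

def shout(text: str, times: int) -> str:  # "times" is never read: ruff reports ARG001
    return text.upper()

# Fix 1: drop the dead argument (and update every call site).
def shout_trimmed(text: str) -> str:
    return text.upper()

# Fix 2: underscore-prefix an argument that must stay for compatibility.
def shout_compat(text: str, _times: int = 1) -> str:
    return text.upper()

# Fix 3: suppress the rule when the signature cannot change.
def shout_suppressed(text: str, times: int) -> str:  # noqa: ARG001
    return text.upper()
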
16 changes: 8 additions & 8 deletions src/backend/base/langflow/__main__.py
@@ -95,12 +95,12 @@ def run(
),
log_level: str | None = typer.Option(None, help="Logging level.", show_default=False),
log_file: Path | None = typer.Option(None, help="Path to the log file.", show_default=False),
cache: str | None = typer.Option(
cache: str | None = typer.Option( # noqa: ARG001
None,
help="Type of cache to use. (InMemoryCache, SQLiteCache)",
show_default=False,
),
dev: bool | None = typer.Option(None, help="Run in development mode (may contain bugs)", show_default=False),
dev: bool | None = typer.Option(None, help="Run in development mode (may contain bugs)", show_default=False), # noqa: ARG001
frontend_path: str | None = typer.Option(
None,
help="Path to the frontend directory containing build files. This is for development purposes only.",
@@ -111,7 +111,7 @@ def run(
help="Open the browser after starting the server.",
show_default=False,
),
remove_api_keys: bool | None = typer.Option(
remove_api_keys: bool | None = typer.Option( # noqa: ARG001
None,
help="Remove API keys from the projects saved in the database.",
show_default=False,
@@ -121,27 +121,27 @@ def run(
help="Run only the backend server without the frontend.",
show_default=False,
),
store: bool | None = typer.Option(
store: bool | None = typer.Option( # noqa: ARG001
None,
help="Enables the store features.",
show_default=False,
),
auto_saving: bool | None = typer.Option(
auto_saving: bool | None = typer.Option( # noqa: ARG001
None,
help="Defines if the auto save is enabled.",
show_default=False,
),
auto_saving_interval: int | None = typer.Option(
auto_saving_interval: int | None = typer.Option( # noqa: ARG001
None,
help="Defines the debounce time for the auto save.",
show_default=False,
),
health_check_max_retries: bool | None = typer.Option(
health_check_max_retries: bool | None = typer.Option( # noqa: ARG001
None,
help="Defines the number of retries for the health check.",
show_default=False,
),
max_file_size_upload: int | None = typer.Option(
max_file_size_upload: int | None = typer.Option( # noqa: ARG001
None,
help="Defines the maximum file size for the upload in MB.",
show_default=False,
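
Most of the # noqa: ARG001 suppressions land here because Typer derives the CLI from the function signature: each option must stay declared for flags like --cache and --dev to parse and appear in --help, even though run()'s body evidently picks the values up through another path (the settings machinery, presumably). A minimal sketch of the pattern, with invented names:

import typer

app = typer.Typer()

@app.command()
def serve(
    host: str = typer.Option("127.0.0.1", help="Bind address."),
    # Declared so the CLI accepts and documents --cache even though this
    # body never reads it; without the noqa, ruff would report ARG001.
    cache: str | None = typer.Option(None, help="Cache backend."),  # noqa: ARG001
):
    typer.echo(f"Serving on {host}")

if __name__ == "__main__":
    app()
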
3 changes: 1 addition & 2 deletions src/backend/base/langflow/api/v1/api_key.py
@@ -46,10 +46,9 @@ def create_api_key_route(
raise HTTPException(status_code=400, detail=str(e)) from e


@router.delete("/{api_key_id}")
@router.delete("/{api_key_id}", dependencies=[Depends(auth_utils.get_current_active_user)])
def delete_api_key_route(
api_key_id: UUID,
current_user=Depends(auth_utils.get_current_active_user),
db: Session = Depends(get_session),
):
try:
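
This file shows the commit's main FastAPI pattern: a dependency needed only for its side effect (here, authentication) moves out of the parameter list and into the decorator's dependencies= argument. The check still runs on every request, but no unused name is bound, so ARG001 is satisfied. A sketch with made-up names:

from fastapi import Depends, FastAPI, HTTPException

app = FastAPI()

def require_active_user() -> None:
    """Stand-in for auth_utils.get_current_active_user; raises on failure."""
    user_is_active = True  # imagine a real lookup here
    if not user_is_active:
        raise HTTPException(status_code=401, detail="Inactive user")

# The dependency executes before the handler, but the handler no longer
# declares an argument it never reads.
@app.delete("/api-keys/{key_id}", dependencies=[Depends(require_active_user)])
def delete_key(key_id: int) -> dict:
    return {"deleted": key_id}
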
22 changes: 17 additions & 5 deletions src/backend/base/langflow/api/v1/callback.py
@@ -6,6 +6,7 @@
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks.base import AsyncCallbackHandler
from loguru import logger
from typing_extensions import override

from langflow.api.v1.schemas import ChatResponse, PromptResponse
from langflow.services.deps import get_chat_service, get_socket_service
@@ -31,11 +32,13 @@ def __init__(self, session_id: str):
self.sid = session_id
# self.socketio_service = self.chat_service.active_connections[self.client_id]

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
@override
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: # type: ignore[misc]
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())

async def on_tool_start(self, serialized: dict[str, Any], input_str: str, **kwargs: Any) -> Any:
@override
async def on_tool_start(self, serialized: dict[str, Any], input_str: str, **kwargs: Any) -> Any: # type: ignore[misc]
"""Run when tool starts running."""
resp = ChatResponse(
message="",
@@ -88,7 +91,10 @@ async def on_tool_error(
) -> None:
"""Run when tool errors."""

async def on_text(self, text: str, **kwargs: Any) -> Any:
@override
async def on_text( # type: ignore[misc]
self, text: str, **kwargs: Any
) -> Any:
"""Run on arbitrary text."""
# This runs when first sending the prompt
# to the LLM, adding it will send the final prompt
@@ -101,7 +107,10 @@ async def on_text(self, text: str, **kwargs: Any) -> Any:
)
await self.socketio_service.emit_message(to=self.sid, data=resp.model_dump())

async def on_agent_action(self, action: AgentAction, **kwargs: Any):
@override
async def on_agent_action( # type: ignore[misc]
self, action: AgentAction, **kwargs: Any
):
log = f"Thought: {action.log}"
# if there are line breaks, split them and send them
# as separate messages
@@ -114,7 +123,10 @@ async def on_agent_action(self, action: AgentAction, **kwargs: Any):
resp = ChatResponse(message="", type="stream", intermediate_steps=log)
await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())

async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
@override
async def on_agent_finish( # type: ignore[misc]
self, finish: AgentFinish, **kwargs: Any
) -> Any:
"""Run on agent end."""
resp = ChatResponse(
message="",
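
The callback handlers take a different route: their signatures are dictated by LangChain's AsyncCallbackHandler, so the unused **kwargs cannot simply be dropped. Decorating each hook with @override from typing_extensions documents that the signature is inherited (ruff exempts explicit overrides from its unused-argument checks, which appears to be the motivation), while the # type: ignore[misc] comments seem to quiet a separate type-checker complaint about the decorated methods. A small sketch of the idea, not tied to langflow's classes:

from typing import Any

from typing_extensions import override

class BaseHandler:
    async def on_token(self, token: str, **kwargs: Any) -> None:
        """Hook whose signature subclasses must preserve."""

class PrintingHandler(BaseHandler):
    @override
    async def on_token(self, token: str, **kwargs: Any) -> None:
        # **kwargs stays for compatibility with the base signature;
        # @override tells both readers and ruff that this is deliberate.
        print(token, end="", flush=True)
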
5 changes: 1 addition & 4 deletions src/backend/base/langflow/api/v1/chat.py
@@ -40,13 +40,12 @@
from langflow.schema.schema import OutputValue
from langflow.services.auth.utils import get_current_active_user
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_chat_service, get_session, get_session_service, get_telemetry_service
from langflow.services.deps import get_chat_service, get_session, get_telemetry_service
from langflow.services.telemetry.schema import ComponentPayload, PlaygroundPayload
from langflow.services.telemetry.service import TelemetryService

if TYPE_CHECKING:
from langflow.graph.vertex.types import InterfaceVertex
from langflow.services.session.service import SessionService

router = APIRouter(tags=["Chat"])

@@ -696,9 +695,7 @@ async def _stream_vertex(flow_id: str, vertex_id: str, chat_service: ChatService
async def build_vertex_stream(
flow_id: uuid.UUID,
vertex_id: str,
session_id: str | None = None,
chat_service: ChatService = Depends(get_chat_service),
session_service: SessionService = Depends(get_session_service),
):
"""Build a vertex instead of the entire graph.
25 changes: 3 additions & 22 deletions src/backend/base/langflow/api/v1/endpoints.py
@@ -1,7 +1,6 @@
from __future__ import annotations

import time
from asyncio import Lock
from http import HTTPStatus
from typing import TYPE_CHECKING, Annotated
from uuid import UUID
@@ -41,21 +40,18 @@
from langflow.services.database.models.flow.utils import get_all_webhook_components_in_flow
from langflow.services.database.models.user.model import User, UserRead
from langflow.services.deps import (
get_cache_service,
get_session,
get_session_service,
get_settings_service,
get_task_service,
get_telemetry_service,
)
from langflow.services.session.service import SessionService
from langflow.services.task.service import TaskService
from langflow.services.telemetry.schema import RunPayload
from langflow.services.telemetry.service import TelemetryService
from langflow.utils.version import get_version_info

if TYPE_CHECKING:
from langflow.services.cache.base import CacheService
from langflow.services.settings.service import SettingsService

router = APIRouter(tags=["Base"])
@@ -64,16 +60,11 @@
@router.get("/all", dependencies=[Depends(get_current_active_user)])
async def get_all(
settings_service=Depends(get_settings_service),
cache_service: CacheService = Depends(dependency=get_cache_service),
force_refresh: bool = False,
):
from langflow.interface.types import get_and_cache_all_types_dict

try:
async with Lock() as lock:
return await get_and_cache_all_types_dict(
settings_service=settings_service, cache_service=cache_service, force_refresh=force_refresh, lock=lock
)
return await get_and_cache_all_types_dict(settings_service=settings_service)

except Exception as exc:
raise HTTPException(status_code=500, detail=str(exc)) from exc
@@ -496,19 +487,9 @@ async def experimental_run_flow(
@router.post(
"/process/{flow_id}",
response_model=ProcessResponse,
dependencies=[Depends(api_key_security)],
)
async def process(
session: Annotated[Session, Depends(get_session)],
flow_id: str,
inputs: list[dict] | dict | None = None,
tweaks: dict | None = None,
clear_cache: Annotated[bool, Body(embed=True)] = False,
session_id: Annotated[None | str, Body(embed=True)] = None,
task_service: TaskService = Depends(get_task_service),
api_key_user: UserRead = Depends(api_key_security),
sync: Annotated[bool, Body(embed=True)] = True,
session_service: SessionService = Depends(get_session_service),
):
async def process():
"""Endpoint to process an input with a given flow_id."""
# Raise a depreciation warning
logger.warning(
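
The /process route gets the bluntest treatment: it was already deprecated, its body never used most of its parameters, and the API-key check moves into the decorator, so the signature is emptied entirely. Sketched roughly below (helper invented; the real handler logs a longer migration message):

from fastapi import APIRouter, Depends, HTTPException
from loguru import logger

router = APIRouter()

def api_key_security() -> None:
    """Stand-in auth dependency; the real one validates an API key."""

# FastAPI does not force the handler to capture {flow_id}: an undeclared
# path parameter is matched on the URL but simply not passed in.
@router.post("/process/{flow_id}", dependencies=[Depends(api_key_security)])
async def process() -> None:
    """Deprecated endpoint kept only to steer callers elsewhere."""
    logger.warning("The /process endpoint is deprecated; use /run instead.")
    raise HTTPException(status_code=410, detail="Use /run instead.")
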
14 changes: 7 additions & 7 deletions src/backend/base/langflow/api/v1/monitor.py
@@ -10,7 +10,6 @@
from langflow.services.database.models.message.model import MessageRead, MessageTable, MessageUpdate
from langflow.services.database.models.transactions.crud import get_transactions_by_flow_id
from langflow.services.database.models.transactions.model import TransactionReadResponse
from langflow.services.database.models.user.model import User
from langflow.services.database.models.vertex_builds.crud import (
delete_vertex_builds_by_flow_id,
get_vertex_builds_by_flow_id,
@@ -72,11 +71,10 @@ async def get_messages(
raise HTTPException(status_code=500, detail=str(e)) from e


@router.delete("/messages", status_code=204)
@router.delete("/messages", status_code=204, dependencies=[Depends(get_current_active_user)])
async def delete_messages(
message_ids: list[UUID],
session: Annotated[Session, Depends(get_session)],
current_user: Annotated[User, Depends(get_current_active_user)],
):
try:
session.exec(delete(MessageTable).where(MessageTable.id.in_(message_ids))) # type: ignore[attr-defined]
@@ -85,12 +83,11 @@ async def update_message(
raise HTTPException(status_code=500, detail=str(e)) from e


@router.put("/messages/{message_id}", response_model=MessageRead)
@router.put("/messages/{message_id}", dependencies=[Depends(get_current_active_user)], response_model=MessageRead)
async def update_message(
message_id: UUID,
message: MessageUpdate,
session: Annotated[Session, Depends(get_session)],
user: Annotated[User, Depends(get_current_active_user)],
):
try:
db_message = session.get(MessageTable, message_id)
@@ -112,12 +109,15 @@ async def update_message(
return db_message


@router.patch("/messages/session/{old_session_id}", response_model=list[MessageResponse])
@router.patch(
"/messages/session/{old_session_id}",
dependencies=[Depends(get_current_active_user)],
response_model=list[MessageResponse],
)
async def update_session_id(
old_session_id: str,
new_session_id: Annotated[str, Query(..., description="The new session ID to update to")],
session: Annotated[Session, Depends(get_session)],
current_user: Annotated[User, Depends(get_current_active_user)],
):
try:
# Get all messages with the old session ID
8 changes: 2 additions & 6 deletions src/backend/base/langflow/api/v1/starter_projects.py
@@ -2,16 +2,12 @@

from langflow.graph.graph.schema import GraphDump
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.user.model import User

router = APIRouter(prefix="/starter-projects", tags=["Flows"])


@router.get("/", response_model=list[GraphDump], status_code=200)
def get_starter_projects(
*,
current_user: User = Depends(get_current_active_user),
):
@router.get("/", dependencies=[Depends(get_current_active_user)], response_model=list[GraphDump], status_code=200)
def get_starter_projects():
"""Get a list of starter projects."""
from langflow.initial_setup.load import get_starter_projects_dump

3 changes: 1 addition & 2 deletions src/backend/base/langflow/api/v1/variable.py
@@ -7,7 +7,7 @@
from langflow.services.auth.utils import get_current_active_user
from langflow.services.database.models.user.model import User
from langflow.services.database.models.variable import VariableCreate, VariableRead, VariableUpdate
from langflow.services.deps import get_session, get_settings_service, get_variable_service
from langflow.services.deps import get_session, get_variable_service
from langflow.services.variable.base import VariableService
from langflow.services.variable.constants import GENERIC_TYPE
from langflow.services.variable.service import DatabaseVariableService
@@ -21,7 +21,6 @@ def create_variable(
session: Session = Depends(get_session),
variable: VariableCreate,
current_user: User = Depends(get_current_active_user),
settings_service=Depends(get_settings_service),
variable_service: DatabaseVariableService = Depends(get_variable_service),
):
"""Create a new variable."""
8 changes: 4 additions & 4 deletions src/backend/base/langflow/base/agents/utils.py
@@ -68,9 +68,9 @@ def validate_and_create_openai_tools_agent(
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
prompt: ChatPromptTemplate,
tools_renderer: Callable[[list[BaseTool]], str] = render_text_description,
_tools_renderer: Callable[[list[BaseTool]], str] = render_text_description,
*,
stop_sequence: bool | list[str] = True,
_stop_sequence: bool | list[str] = True,
):
return create_openai_tools_agent(
llm=llm,
@@ -83,9 +83,9 @@ def validate_and_create_tool_calling_agent(
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
prompt: ChatPromptTemplate,
tools_renderer: Callable[[list[BaseTool]], str] = render_text_description,
_tools_renderer: Callable[[list[BaseTool]], str] = render_text_description,
*,
stop_sequence: bool | list[str] = True,
_stop_sequence: bool | list[str] = True,
):
return create_tool_calling_agent(
llm=llm,
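
These factories show the third idiom: _tools_renderer and _stop_sequence are part of a signature shared across agent factories, so they can be neither used nor removed, and the underscore prefix marks them as intentionally ignored; ruff's unused-argument rules skip parameters matching the configured dummy-name pattern, which by default covers underscore-prefixed names. A toy version:

from collections.abc import Callable

def make_agent(
    llm_name: str,
    prompt: str,
    # Required by the shared factory signature but unused in this variant;
    # the leading underscore keeps ruff quiet and signals intent to readers.
    _tools_renderer: Callable[[list[str]], str] | None = None,
) -> str:
    return f"{llm_name}: {prompt}"

print(make_agent("demo-llm", "say hi"))
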
5 changes: 2 additions & 3 deletions src/backend/base/langflow/base/flow_processing/utils.py
@@ -24,12 +24,11 @@ def build_data_from_run_outputs(run_outputs: RunOutputs) -> list[Data]:
return data


def build_data_from_result_data(result_data: ResultData, get_final_results_only: bool = True) -> list[Data]:
def build_data_from_result_data(result_data: ResultData) -> list[Data]:
"""Build a list of data from the given ResultData.
Args:
result_data (ResultData): The ResultData object containing the result data.
get_final_results_only (bool, optional): Whether to include only final results. Defaults to True.
Returns:
List[Data]: A list of data built from the ResultData.
@@ -64,7 +63,7 @@ def build_data_from_result_data(result_data: ResultData, get_final_results_only:

if isinstance(result_data.results, dict):
for name, result in result_data.results.items():
dataobj: Data | Message | None = None
dataobj: Data | Message | None
dataobj = result if isinstance(result, Message) else Data(data=result, text_key=name)

data.append(dataobj)
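
When an argument is genuinely dead, as with get_final_results_only above, the fix ripples outward: the parameter, its docstring entry, and every call site are removed together (the matching call-site updates appear in the next file). The final hunk is a small related cleanup: the = None initializer on dataobj was redundant because the next line assigns it on both branches. A toy before/after with hypothetical names:

# Before: "verbose" is never read, so ruff reports ARG001, and the
# docstring advertises an argument that does nothing.
def summarize(parts: list[str], verbose: bool = True) -> str:
    return ", ".join(parts)

# After: signature, docstring, and callers all shrink together.
def summarize_clean(parts: list[str]) -> str:
    return ", ".join(parts)

print(summarize_clean(["a", "b"]))
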
10 changes: 7 additions & 3 deletions src/backend/base/langflow/base/tools/flow_tool.py
@@ -4,6 +4,7 @@

from langchain_core.tools import BaseTool, ToolException
from loguru import logger
from typing_extensions import override

from langflow.base.flow_processing.utils import build_data_from_result_data, format_flow_output_data
from langflow.graph.graph.base import Graph # cannot be a part of TYPE_CHECKING # noqa: TCH001
@@ -30,7 +31,10 @@ def args(self) -> dict:
schema = self.get_input_schema()
return schema.schema()["properties"]

def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
@override
def get_input_schema( # type: ignore[misc]
self, config: RunnableConfig | None = None
) -> type[BaseModel]:
"""The tool's input schema."""
if self.args_schema is not None:
return self.args_schema
@@ -68,7 +72,7 @@ def _run(
if run_output is not None:
for output in run_output.outputs:
if output:
data.extend(build_data_from_result_data(output, get_final_results_only=self.get_final_results_only))
data.extend(build_data_from_result_data(output))
return format_flow_output_data(data)

def validate_inputs(self, args_names: list[dict[str, str]], args: Any, kwargs: Any):
@@ -118,5 +122,5 @@ async def _arun(
if run_output is not None:
for output in run_output.outputs:
if output:
data.extend(build_data_from_result_data(output, get_final_results_only=self.get_final_results_only))
data.extend(build_data_from_result_data(output))
return format_flow_output_data(data)