feat: Activate ruff rules UP(pyupgrade) #3871

Merged 6 commits on Sep 25, 2024
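
For context, the pyupgrade family is enabled by adding the "UP" prefix to ruff's rule selection in pyproject.toml. A minimal sketch of what that change looks like — the neighboring entries are placeholders, not langflow's actual configuration:

[tool.ruff.lint]
select = [
    "E",   # pycodestyle errors (placeholder entry)
    "F",   # pyflakes (placeholder entry)
    "UP",  # pyupgrade: rewrite syntax the target Python version has outgrown
]

With the rules selected, `ruff check --select UP --fix .` applies the automatic rewrites. The commits below are essentially that output: Optional[X] and Union[X, Y] become X | None and X | Y (UP007), typing.List/Dict/Tuple become builtin generics (UP006), quoted forward references lose their quotes where `from __future__ import annotations` makes them unnecessary (UP037), and deprecated typing imports move to collections.abc (UP035).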
7 changes: 3 additions & 4 deletions src/backend/base/langflow/__main__.py
@@ -4,7 +4,6 @@
 import time
 import warnings
 from pathlib import Path
-from typing import Optional

 import click
 import httpx
@@ -84,7 +83,7 @@ def run(
     workers: int = typer.Option(1, help="Number of worker processes.", envvar="LANGFLOW_WORKERS"),
     timeout: int = typer.Option(300, help="Worker timeout in seconds.", envvar="LANGFLOW_WORKER_TIMEOUT"),
     port: int = typer.Option(7860, help="Port to listen on.", envvar="LANGFLOW_PORT"),
-    components_path: Optional[Path] = typer.Option(
+    components_path: Path | None = typer.Option(
         Path(__file__).parent / "components",
         help="Path to the directory containing custom components.",
         envvar="LANGFLOW_COMPONENTS_PATH",
@@ -93,7 +92,7 @@ def run(
     env_file: Path = typer.Option(None, help="Path to the .env file containing environment variables."),
     log_level: str = typer.Option("critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"),
     log_file: Path = typer.Option("logs/langflow.log", help="Path to the log file.", envvar="LANGFLOW_LOG_FILE"),
-    cache: Optional[str] = typer.Option(
+    cache: str | None = typer.Option(
         envvar="LANGFLOW_LANGCHAIN_CACHE",
         help="Type of cache to use. (InMemoryCache, SQLiteCache)",
         default=None,
@@ -161,7 +160,7 @@ def run(
         health_check_max_retries=health_check_max_retries,
     )
     # create path object if path is provided
-    static_files_dir: Optional[Path] = Path(path) if path else None
+    static_files_dir: Path | None = Path(path) if path else None
     settings_service = get_settings_service()
     settings_service.set("backend_only", backend_only)
     app = setup_app(static_files_dir=static_files_dir, backend_only=backend_only)
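
The rewrite above is UP007: Optional[X] becomes the PEP 604 form X | None. One wrinkle worth noting, shown here as a minimal sketch with an illustrative function rather than langflow's actual code:

from pathlib import Path
from typing import Optional


# Before: the pre-3.10 spelling that UP007 flags.
def resolve(components_path: Optional[Path] = None) -> Optional[str]:
    return str(components_path) if components_path else None


# After: PEP 604 syntax. Because __main__.py does not add
# `from __future__ import annotations`, these annotations are evaluated
# eagerly, so the | form needs Python >= 3.10 at runtime when a framework
# such as typer inspects them — fine, assuming the project already
# targets 3.10+.
def resolve_new(components_path: Path | None = None) -> str | None:
    return str(components_path) if components_path else None


print(resolve_new(Path("components")))  # -> components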
8 changes: 5 additions & 3 deletions src/backend/base/langflow/api/utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import uuid
 import warnings
 from typing import TYPE_CHECKING, Any
@@ -69,7 +71,7 @@ def build_input_keys_response(langchain_object, artifacts):
     return input_keys_response


-def validate_is_component(flows: list["Flow"]):
+def validate_is_component(flows: list[Flow]):
     for flow in flows:
         if not flow.data or flow.is_component is not None:
             continue
@@ -152,15 +154,15 @@ async def build_graph_from_db_no_cache(flow_id: str, session: Session):
     return await build_graph_from_data(flow_id, flow.data, flow_name=flow.name, user_id=str(flow.user_id))


-async def build_graph_from_db(flow_id: str, session: Session, chat_service: "ChatService"):
+async def build_graph_from_db(flow_id: str, session: Session, chat_service: ChatService):
     graph = await build_graph_from_db_no_cache(flow_id, session)
     await chat_service.set_cache(flow_id, graph)
     return graph


 async def build_and_cache_graph_from_data(
     flow_id: str,
-    chat_service: "ChatService",
+    chat_service: ChatService,
     graph_data: dict,
 ):  # -> Graph | Any:
     """Build and cache the graph."""
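
The new `from __future__ import annotations` line is what lets UP037 drop the quotes around names like "Flow" and "ChatService" that are imported only under TYPE_CHECKING: under PEP 563 every annotation is stored as a string and never evaluated at definition time. A self-contained sketch — the import path is illustrative, not langflow's actual module layout:

from __future__ import annotations  # PEP 563: annotations become lazy strings

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; this import never runs.
    from langflow.services.chat.service import ChatService  # hypothetical path


def build_graph(chat_service: ChatService) -> None:
    # Without the future import, the unquoted ChatService would raise
    # NameError when the def statement executes; with it, the annotation
    # simply stays the string "ChatService".
    ...


print(build_graph.__annotations__)  # {'chat_service': 'ChatService', 'return': 'None'}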
4 changes: 3 additions & 1 deletion src/backend/base/langflow/api/v1/callback.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from typing import TYPE_CHECKING, Any
 from uuid import UUID

@@ -25,7 +27,7 @@ def ignore_chain(self) -> bool:
     def __init__(self, session_id: str):
         self.chat_service = get_chat_service()
         self.client_id = session_id
-        self.socketio_service: "SocketIOService" = get_socket_service()
+        self.socketio_service: SocketIOService = get_socket_service()
         self.sid = session_id
         # self.socketio_service = self.chat_service.active_connections[self.client_id]
30 changes: 16 additions & 14 deletions src/backend/base/langflow/api/v1/chat.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import json
 import time
@@ -72,9 +74,9 @@ async def retrieve_vertices_order(
     data: Annotated[FlowDataRequest | None, Body(embed=True)] | None = None,
     stop_component_id: str | None = None,
     start_component_id: str | None = None,
-    chat_service: "ChatService" = Depends(get_chat_service),
+    chat_service: ChatService = Depends(get_chat_service),
     session=Depends(get_session),
-    telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
+    telemetry_service: TelemetryService = Depends(get_telemetry_service),
 ):
     """
     Retrieve the vertices order for a given flow.
@@ -148,12 +150,12 @@ async def build_flow(
     stop_component_id: str | None = None,
     start_component_id: str | None = None,
     log_builds: bool | None = True,
-    chat_service: "ChatService" = Depends(get_chat_service),
+    chat_service: ChatService = Depends(get_chat_service),
     current_user=Depends(get_current_active_user),
-    telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
+    telemetry_service: TelemetryService = Depends(get_telemetry_service),
     session=Depends(get_session),
 ):
-    async def build_graph_and_get_order() -> tuple[list[str], list[str], "Graph"]:
+    async def build_graph_and_get_order() -> tuple[list[str], list[str], Graph]:
         start_time = time.perf_counter()
         components_count = None
         try:
@@ -205,7 +207,7 @@ async def build_graph_and_get_order() -> tuple[list[str], list[str], "Graph"]:
             logger.exception(exc)
             raise HTTPException(status_code=500, detail=str(exc)) from exc

-    async def _build_vertex(vertex_id: str, graph: "Graph", event_manager: "EventManager") -> VertexBuildResponse:
+    async def _build_vertex(vertex_id: str, graph: Graph, event_manager: EventManager) -> VertexBuildResponse:
         flow_id_str = str(flow_id)

         next_runnable_vertices = []
@@ -320,9 +322,9 @@ async def _build_vertex(vertex_id: str, graph: "Graph", event_manager: "EventManager") -> VertexBuildResponse:

     async def build_vertices(
         vertex_id: str,
-        graph: "Graph",
+        graph: Graph,
         client_consumed_queue: asyncio.Queue,
-        event_manager: "EventManager",
+        event_manager: EventManager,
     ) -> None:
         build_task = asyncio.create_task(asyncio.to_thread(asyncio.run, _build_vertex(vertex_id, graph, event_manager)))
         try:
@@ -457,9 +459,9 @@ async def build_vertex(
     background_tasks: BackgroundTasks,
     inputs: Annotated[InputValueRequest | None, Body(embed=True)] = None,
     files: list[str] | None = None,
-    chat_service: "ChatService" = Depends(get_chat_service),
+    chat_service: ChatService = Depends(get_chat_service),
     current_user=Depends(get_current_active_user),
-    telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
+    telemetry_service: TelemetryService = Depends(get_telemetry_service),
 ):
     """Build a vertex instead of the entire graph.

@@ -489,7 +491,7 @@ async def build_vertex(
         if not cache:
             # If there's no cache
             logger.warning(f"No cache found for {flow_id_str}. Building graph starting at {vertex_id}")
-            graph: "Graph" = await build_graph_from_db(
+            graph: Graph = await build_graph_from_db(
                 flow_id=flow_id_str, session=next(get_session()), chat_service=chat_service
             )
         else:
@@ -609,8 +611,8 @@ async def build_vertex_stream(
     flow_id: uuid.UUID,
     vertex_id: str,
     session_id: str | None = None,
-    chat_service: "ChatService" = Depends(get_chat_service),
-    session_service: "SessionService" = Depends(get_session_service),
+    chat_service: ChatService = Depends(get_chat_service),
+    session_service: SessionService = Depends(get_session_service),
 ):
     """Build a vertex instead of the entire graph.

@@ -650,7 +652,7 @@ async def stream_vertex():
         else:
             graph = cache.get("result")

-        vertex: "InterfaceVertex" = graph.get_vertex(vertex_id)
+        vertex: InterfaceVertex = graph.get_vertex(vertex_id)
         if not hasattr(vertex, "stream"):
             raise ValueError(f"Vertex {vertex_id} does not support streaming")
         if isinstance(vertex._built_result, str) and vertex._built_result:
12 changes: 7 additions & 5 deletions src/backend/base/langflow/api/v1/endpoints.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import time
 from asyncio import Lock
 from http import HTTPStatus
@@ -62,7 +64,7 @@
 @router.get("/all", dependencies=[Depends(get_current_active_user)])
 async def get_all(
     settings_service=Depends(get_settings_service),
-    cache_service: "CacheService" = Depends(dependency=get_cache_service),
+    cache_service: CacheService = Depends(dependency=get_cache_service),
     force_refresh: bool = False,
 ):
     from langflow.interface.types import get_and_cache_all_types_dict
@@ -181,7 +183,7 @@ async def simplified_run_flow(
     input_request: SimplifiedAPIRequest = SimplifiedAPIRequest(),
     stream: bool = False,
     api_key_user: UserRead = Depends(api_key_security),
-    telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
+    telemetry_service: TelemetryService = Depends(get_telemetry_service),
 ):
     """
     Executes a specified flow by ID with input customization, performance enhancements through caching, and optional data streaming.
@@ -290,7 +292,7 @@ async def webhook_run_flow(
     user: Annotated[User, Depends(get_user_by_flow_id_or_endpoint_name)],
     request: Request,
     background_tasks: BackgroundTasks,
-    telemetry_service: "TelemetryService" = Depends(get_telemetry_service),
+    telemetry_service: TelemetryService = Depends(get_telemetry_service),
 ):
     """
     Run a flow using a webhook request.
@@ -484,7 +486,7 @@ async def process(
     tweaks: dict | None = None,
     clear_cache: Annotated[bool, Body(embed=True)] = False,  # noqa: F821
     session_id: Annotated[None | str, Body(embed=True)] = None,  # noqa: F821
-    task_service: "TaskService" = Depends(get_task_service),
+    task_service: TaskService = Depends(get_task_service),
     api_key_user: UserRead = Depends(api_key_security),
     sync: Annotated[bool, Body(embed=True)] = True,  # noqa: F821
     session_service: SessionService = Depends(get_session_service),
@@ -633,7 +635,7 @@ def get_config():
     try:
         from langflow.services.deps import get_settings_service

-        settings_service: "SettingsService" = get_settings_service()  # type: ignore
+        settings_service: SettingsService = get_settings_service()  # type: ignore
         return settings_service.settings.model_dump()
     except Exception as exc:
         logger.exception(exc)
8 changes: 4 additions & 4 deletions src/backend/base/langflow/base/agents/agent.py
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import List, Optional, Union, cast
+from typing import cast

 from langchain.agents import AgentExecutor, BaseMultiActionAgent, BaseSingleActionAgent
 from langchain.agents.agent import RunnableAgent
@@ -20,7 +20,7 @@

 class LCAgentComponent(Component):
     trace_type = "agent"
-    _base_inputs: List[InputTypes] = [
+    _base_inputs: list[InputTypes] = [
         MessageTextInput(name="input_value", display_name="Input"),
         BoolInput(
             name="handle_parsing_errors",
@@ -89,7 +89,7 @@ def get_agent_kwargs(self, flatten: bool = False) -> dict:
         }
         return {**base, "agent_executor_kwargs": agent_kwargs}

-    def get_chat_history_data(self) -> Optional[List[Data]]:
+    def get_chat_history_data(self) -> list[Data] | None:
         # might be overridden in subclasses
         return None

@@ -128,7 +128,7 @@ def build_agent(self) -> AgentExecutor:

     async def run_agent(
         self,
-        agent: Union[Runnable, BaseSingleActionAgent, BaseMultiActionAgent, AgentExecutor],
+        agent: Runnable | BaseSingleActionAgent | BaseMultiActionAgent | AgentExecutor,
     ) -> Text:
         if isinstance(agent, AgentExecutor):
             runnable = agent
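
UP007 handles multi-member Unions the same way as Optional. On Python >= 3.10 the | syntax builds a types.UnionType at runtime, so it is not just for type checkers — a quick illustration:

value: int | str = "agent output"
print(isinstance(value, int | str))  # True: | unions work in isinstance() on 3.10+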
14 changes: 7 additions & 7 deletions src/backend/base/langflow/base/agents/callback.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any
 from uuid import UUID

 from langchain.callbacks.base import AsyncCallbackHandler
@@ -15,14 +15,14 @@ def __init__(self, log_function: LogFunctionType | None = None):

     async def on_tool_start(
         self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
         input_str: str,
         *,
         run_id: UUID,
         parent_run_id: UUID | None = None,
-        tags: List[str] | None = None,
-        metadata: Dict[str, Any] | None = None,
-        inputs: Dict[str, Any] | None = None,
+        tags: list[str] | None = None,
+        metadata: dict[str, Any] | None = None,
+        inputs: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         if self.log_function is None:
@@ -62,7 +62,7 @@ async def on_agent_action(
         *,
         run_id: UUID,
         parent_run_id: UUID | None = None,
-        tags: List[str] | None = None,
+        tags: list[str] | None = None,
         **kwargs: Any,
     ) -> None:
         if self.log_function is None:
@@ -85,7 +85,7 @@ async def on_agent_finish(
         *,
         run_id: UUID,
         parent_run_id: UUID | None = None,
-        tags: List[str] | None = None,
+        tags: list[str] | None = None,
         **kwargs: Any,
     ) -> None:
         if self.log_function is None:
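
These changes are UP006: since Python 3.9 (PEP 585) the builtin dict and list are subscriptable, so the typing.Dict and typing.List aliases, deprecated since then, can go. A minimal before/after sketch with illustrative signatures:

from typing import Any, Dict, List  # Dict/List kept only for the "before" line


# Before: deprecated typing aliases.
def on_tool_start_old(serialized: Dict[str, Any], tags: List[str] | None = None) -> None: ...


# After: builtin generics; the typing import shrinks to just Any.
def on_tool_start_new(serialized: dict[str, Any], tags: list[str] | None = None) -> None: ...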
5 changes: 3 additions & 2 deletions src/backend/base/langflow/base/agents/crewai/crew.py
@@ -1,4 +1,5 @@
-from typing import Callable, List, Tuple, Union, cast
+from collections.abc import Callable
+from typing import cast

 from crewai import Agent, Crew, Process, Task  # type: ignore
 from crewai.task import TaskOutput  # type: ignore
@@ -62,7 +63,7 @@ def task_callback(task_output: TaskOutput):
     def get_step_callback(
         self,
     ) -> Callable:
-        def step_callback(agent_output: Union[AgentFinish, List[Tuple[AgentAction, str]]]):
+        def step_callback(agent_output: AgentFinish | list[tuple[AgentAction, str]]):
             _id = self._vertex.id if self._vertex else self.display_name
             if isinstance(agent_output, AgentFinish):
                 messages = agent_output.messages
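
The import swap here is UP035: typing.Callable is a deprecated alias whose canonical home since Python 3.9 is collections.abc, and the abc version subscripts the same way. A small runnable sketch:

from collections.abc import Callable  # preferred over typing.Callable (UP035)


def get_step_callback() -> Callable[[str], None]:
    def step_callback(agent_output: str) -> None:
        print(f"step: {agent_output}")

    return step_callback


get_step_callback()("done")  # prints: step: done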