diff --git a/src/backend/base/langflow/__main__.py b/src/backend/base/langflow/__main__.py index 2dc231b6a876..437995c1d9cf 100644 --- a/src/backend/base/langflow/__main__.py +++ b/src/backend/base/langflow/__main__.py @@ -45,7 +45,7 @@ def get_number_of_workers(workers=None): return workers -def display_results(results): +def display_results(results) -> None: """Display the results of the migration.""" for table_results in results: table = Table(title=f"Migration {table_results.table_name}") @@ -62,7 +62,7 @@ def display_results(results): console.print() # Print a new line -def set_var_for_macos_issue(): +def set_var_for_macos_issue() -> None: # OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES # we need to set this var is we are running on MacOS # otherwise we get an error when running gunicorn @@ -146,7 +146,7 @@ def run( help="Defines the maximum file size for the upload in MB.", show_default=False, ), -): +) -> None: """Run Langflow.""" configure(log_level=log_level, log_file=log_file) set_var_for_macos_issue() @@ -202,7 +202,7 @@ def run( # Run using uvicorn on MacOS and Windows # Windows doesn't support gunicorn # MacOS requires an env variable to be set to use gunicorn - process = run_on_windows(host, port, log_level, options, app) + run_on_windows(host, port, log_level, options, app) else: # Run using gunicorn on Linux process = run_on_mac_or_linux(host, port, log_level, options, app) @@ -219,7 +219,7 @@ def run( sys.exit(1) -def wait_for_server_ready(host, port): +def wait_for_server_ready(host, port) -> None: """Wait for the server to become ready by polling the health endpoint.""" status_code = 0 while status_code != httpx.codes.OK: @@ -241,7 +241,7 @@ def run_on_mac_or_linux(host, port, log_level, options, app): return webapp_process -def run_on_windows(host, port, log_level, options, app): +def run_on_windows(host, port, log_level, options, app) -> None: """Run the Langflow server on Windows.""" print_banner(host, port) run_langflow(host, port, log_level, options, app) @@ -275,7 +275,7 @@ def get_free_port(port): return port -def get_letter_from_version(version: str): +def get_letter_from_version(version: str) -> str | None: """Get the letter from a pre-release version.""" if "a" in version: return "a" @@ -294,7 +294,7 @@ def build_version_notice(current_version: str, package_name: str) -> str: return "" -def generate_pip_command(package_names, is_pre_release): +def generate_pip_command(package_names, is_pre_release) -> str: """Generate the pip install command based on the packages and whether it's a pre-release.""" base_command = "pip install" if is_pre_release: @@ -309,7 +309,7 @@ def stylize_text(text: str, to_style: str, *, is_prerelease: bool) -> str: return text.replace(to_style, styled_text) -def print_banner(host: str, port: int): +def print_banner(host: str, port: int) -> None: notices = [] package_names = [] # Track package names for pip install instructions is_pre_release = False # Track if any package is a pre-release @@ -355,7 +355,7 @@ def print_banner(host: str, port: int): rprint(panel) -def run_langflow(host, port, log_level, options, app): +def run_langflow(host, port, log_level, options, app) -> None: """Run Langflow server on localhost.""" if platform.system() == "Windows": # Run using uvicorn on MacOS and Windows @@ -381,7 +381,7 @@ def superuser( username: str = typer.Option(..., prompt=True, help="Username for the superuser."), password: str = typer.Option(..., prompt=True, hide_input=True, help="Password for the superuser."), log_level: str = typer.Option("error", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"), -): +) -> None: """Create a superuser.""" configure(log_level=log_level) initialize_services() @@ -413,7 +413,7 @@ def superuser( # command to copy the langflow database from the cache to the current directory # because now the database is stored per installation @app.command() -def copy_db(): +def copy_db() -> None: """Copy the database files to the current directory. This function copies the 'langflow.db' and 'langflow-pre.db' files from the cache directory to the current @@ -452,7 +452,7 @@ def migration( default=False, help="Fix migrations. This is a destructive operation, and should only be used if you know what you are doing.", ), -): +) -> None: """Run or test migrations.""" if fix and not typer.confirm( "This will delete all data necessary to fix migrations. Are you sure you want to continue?" @@ -470,7 +470,7 @@ def migration( @app.command() def api_key( log_level: str = typer.Option("error", help="Logging level."), -): +) -> None: """Creates an API key for the default superuser if AUTO_LOGIN is enabled. Args: @@ -510,7 +510,7 @@ def api_key( api_key_banner(unmasked_api_key) -def api_key_banner(unmasked_api_key): +def api_key_banner(unmasked_api_key) -> None: is_mac = platform.system() == "Darwin" import pyperclip @@ -529,7 +529,7 @@ def api_key_banner(unmasked_api_key): console.print(panel) -def main(): +def main() -> None: with warnings.catch_warnings(): warnings.simplefilter("ignore") app() diff --git a/src/backend/base/langflow/api/utils.py b/src/backend/base/langflow/api/utils.py index e2428a8d2141..b8aa21af594a 100644 --- a/src/backend/base/langflow/api/utils.py +++ b/src/backend/base/langflow/api/utils.py @@ -93,7 +93,7 @@ def get_is_component_from_data(data: dict): return data.get("is_component") -async def check_langflow_version(component: StoreComponentCreate): +async def check_langflow_version(component: StoreComponentCreate) -> None: from langflow.utils.version import get_version_info __version__ = get_version_info()["version"] @@ -259,7 +259,7 @@ def parse_value(value: Any, input_type: str) -> Any: return value -async def cascade_delete_flow(session: Session, flow: Flow): +async def cascade_delete_flow(session: Session, flow: Flow) -> None: try: session.exec(delete(TransactionTable).where(TransactionTable.flow_id == flow.id)) session.exec(delete(VertexBuildTable).where(VertexBuildTable.flow_id == flow.id)) diff --git a/src/backend/base/langflow/api/v1/callback.py b/src/backend/base/langflow/api/v1/callback.py index 2049cd260e5b..527241a64eb3 100644 --- a/src/backend/base/langflow/api/v1/callback.py +++ b/src/backend/base/langflow/api/v1/callback.py @@ -110,7 +110,7 @@ async def on_text( # type: ignore[misc] @override async def on_agent_action( # type: ignore[misc] self, action: AgentAction, **kwargs: Any - ): + ) -> None: log = f"Thought: {action.log}" # if there are line breaks, split them and send them # as separate messages diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index a6fd37a5669d..2b9f403564d6 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -419,7 +419,7 @@ async def consume_and_yield(queue: asyncio.Queue, client_consumed_queue: asyncio event_manager = create_default_event_manager(queue=asyncio_queue) main_task = asyncio.create_task(event_generator(event_manager, asyncio_queue_client_consumed)) - def on_disconnect(): + def on_disconnect() -> None: logger.debug("Client disconnected, closing tasks") main_task.cancel() diff --git a/src/backend/base/langflow/api/v1/endpoints.py b/src/backend/base/langflow/api/v1/endpoints.py index 456ebd4c7338..dbe496d1e183 100644 --- a/src/backend/base/langflow/api/v1/endpoints.py +++ b/src/backend/base/langflow/api/v1/endpoints.py @@ -16,7 +16,6 @@ CustomComponentRequest, CustomComponentResponse, InputValueRequest, - ProcessResponse, RunResponse, SidebarCategoriesResponse, SimplifiedAPIRequest, @@ -68,7 +67,7 @@ async def get_all(): raise HTTPException(status_code=500, detail=str(exc)) from exc -def validate_input_and_tweaks(input_request: SimplifiedAPIRequest): +def validate_input_and_tweaks(input_request: SimplifiedAPIRequest) -> None: # If the input_value is not None and the input_type is "chat" # then we need to check the tweaks if the ChatInput component is present # and if its input_value is not None @@ -483,15 +482,13 @@ async def experimental_run_flow( @router.post( "/predict/{flow_id}", - response_model=ProcessResponse, dependencies=[Depends(api_key_security)], ) @router.post( "/process/{flow_id}", - response_model=ProcessResponse, dependencies=[Depends(api_key_security)], ) -async def process(): +async def process() -> None: """Endpoint to process an input with a given flow_id.""" # Raise a depreciation warning logger.warning( diff --git a/src/backend/base/langflow/api/v1/monitor.py b/src/backend/base/langflow/api/v1/monitor.py index b6bddc779929..b1c964d78fe2 100644 --- a/src/backend/base/langflow/api/v1/monitor.py +++ b/src/backend/base/langflow/api/v1/monitor.py @@ -36,7 +36,7 @@ async def get_vertex_builds( async def delete_vertex_builds( flow_id: Annotated[UUID, Query()], session: Annotated[Session, Depends(get_session)], -): +) -> None: try: delete_vertex_builds_by_flow_id(session, flow_id) except Exception as e: @@ -75,7 +75,7 @@ async def get_messages( async def delete_messages( message_ids: list[UUID], session: Annotated[Session, Depends(get_session)], -): +) -> None: try: session.exec(delete(MessageTable).where(MessageTable.id.in_(message_ids))) # type: ignore[attr-defined] session.commit() diff --git a/src/backend/base/langflow/api/v1/variable.py b/src/backend/base/langflow/api/v1/variable.py index c77ee2abdc2a..7293c1859e00 100644 --- a/src/backend/base/langflow/api/v1/variable.py +++ b/src/backend/base/langflow/api/v1/variable.py @@ -95,7 +95,7 @@ def delete_variable( variable_id: UUID, current_user: User = Depends(get_current_active_user), variable_service: VariableService = Depends(get_variable_service), -): +) -> None: """Delete a variable.""" try: variable_service.delete_variable_by_id(user_id=current_user.id, variable_id=variable_id, session=session) diff --git a/src/backend/base/langflow/base/agents/agent.py b/src/backend/base/langflow/base/agents/agent.py index 95098d6e616b..7219c6385265 100644 --- a/src/backend/base/langflow/base/agents/agent.py +++ b/src/backend/base/langflow/base/agents/agent.py @@ -64,7 +64,7 @@ async def message_response(self) -> Message: self.status = message return message - def _validate_outputs(self): + def _validate_outputs(self) -> None: required_output_methods = ["build_agent"] output_names = [output.name for output in self.outputs] for method_name in required_output_methods: diff --git a/src/backend/base/langflow/base/agents/crewai/crew.py b/src/backend/base/langflow/base/agents/crewai/crew.py index 5ceb538feb17..c2cc019862c7 100644 --- a/src/backend/base/langflow/base/agents/crewai/crew.py +++ b/src/backend/base/langflow/base/agents/crewai/crew.py @@ -52,7 +52,7 @@ def build_crew(self) -> Crew: def get_task_callback( self, ) -> Callable: - def task_callback(task_output: TaskOutput): + def task_callback(task_output: TaskOutput) -> None: vertex_id = self._vertex.id if self._vertex else self.display_name or self.__class__.__name__ self.log(task_output.model_dump(), name=f"Task (Agent: {task_output.agent}) - {vertex_id}") @@ -61,7 +61,7 @@ def task_callback(task_output: TaskOutput): def get_step_callback( self, ) -> Callable: - def step_callback(agent_output: AgentFinish | list[tuple[AgentAction, str]]): + def step_callback(agent_output: AgentFinish | list[tuple[AgentAction, str]]) -> None: _id = self._vertex.id if self._vertex else self.display_name if isinstance(agent_output, AgentFinish): messages = agent_output.messages diff --git a/src/backend/base/langflow/base/astra_assistants/util.py b/src/backend/base/langflow/base/astra_assistants/util.py index ee5f2e75dd77..5973da87c67d 100644 --- a/src/backend/base/langflow/base/astra_assistants/util.py +++ b/src/backend/base/langflow/base/astra_assistants/util.py @@ -36,7 +36,7 @@ def get_patched_openai_client(shared_component_cache): tools_and_names = {} -def tools_from_package(your_package): +def tools_from_package(your_package) -> None: # Iterate over all modules in the package package_name = your_package.__name__ for module_info in pkgutil.iter_modules(your_package.__path__): diff --git a/src/backend/base/langflow/base/chains/model.py b/src/backend/base/langflow/base/chains/model.py index 2f907450c01d..4f9d190f5b35 100644 --- a/src/backend/base/langflow/base/chains/model.py +++ b/src/backend/base/langflow/base/chains/model.py @@ -7,7 +7,7 @@ class LCChainComponent(Component): outputs = [Output(display_name="Text", name="text", method="invoke_chain")] - def _validate_outputs(self): + def _validate_outputs(self) -> None: required_output_methods = ["invoke_chain"] output_names = [output.name for output in self.outputs] for method_name in required_output_methods: diff --git a/src/backend/base/langflow/base/embeddings/model.py b/src/backend/base/langflow/base/embeddings/model.py index cb16b170a12e..182762bfbdc5 100644 --- a/src/backend/base/langflow/base/embeddings/model.py +++ b/src/backend/base/langflow/base/embeddings/model.py @@ -10,7 +10,7 @@ class LCEmbeddingsModel(Component): Output(display_name="Embeddings", name="embeddings", method="build_embeddings"), ] - def _validate_outputs(self): + def _validate_outputs(self) -> None: required_output_methods = ["build_embeddings"] output_names = [output.name for output in self.outputs] for method_name in required_output_methods: diff --git a/src/backend/base/langflow/base/io/chat.py b/src/backend/base/langflow/base/io/chat.py index ce21413bf423..5945bf825e3b 100644 --- a/src/backend/base/langflow/base/io/chat.py +++ b/src/backend/base/langflow/base/io/chat.py @@ -29,7 +29,7 @@ def store_message(self, message: Message) -> Message: self.status = stored_message return stored_message - def _send_message_event(self, message: Message): + def _send_message_event(self, message: Message) -> None: if hasattr(self, "_event_manager") and self._event_manager: self._event_manager.on_message(data=message.data) @@ -107,7 +107,7 @@ def _create_message(self, input_value, sender, sender_name, files, session_id) - return Message.from_data(input_value) return Message(text=input_value, sender=sender, sender_name=sender_name, files=files, session_id=session_id) - def _send_messages_events(self, messages): + def _send_messages_events(self, messages) -> None: if hasattr(self, "_event_manager") and self._event_manager: for stored_message in messages: self._event_manager.on_message(data=stored_message.data) diff --git a/src/backend/base/langflow/base/langchain_utilities/model.py b/src/backend/base/langflow/base/langchain_utilities/model.py index ed6562f17163..b8fca27510b1 100644 --- a/src/backend/base/langflow/base/langchain_utilities/model.py +++ b/src/backend/base/langflow/base/langchain_utilities/model.py @@ -14,7 +14,7 @@ class LCToolComponent(Component): Output(name="api_build_tool", display_name="Tool", method="build_tool"), ] - def _validate_outputs(self): + def _validate_outputs(self) -> None: required_output_methods = ["run_model", "build_tool"] output_names = [output.name for output in self.outputs] for method_name in required_output_methods: diff --git a/src/backend/base/langflow/base/memory/memory.py b/src/backend/base/langflow/base/memory/memory.py index b0b73f84ff10..2b6f0c59bec2 100644 --- a/src/backend/base/langflow/base/memory/memory.py +++ b/src/backend/base/langflow/base/memory/memory.py @@ -45,5 +45,5 @@ def get_messages(self, **kwargs) -> list[Data]: def add_message( self, sender: str, sender_name: str, text: str, session_id: str, metadata: dict | None = None, **kwargs - ): + ) -> None: raise NotImplementedError diff --git a/src/backend/base/langflow/base/memory/model.py b/src/backend/base/langflow/base/memory/model.py index 6ccac29f85d5..a33c7894ac86 100644 --- a/src/backend/base/langflow/base/memory/model.py +++ b/src/backend/base/langflow/base/memory/model.py @@ -17,7 +17,7 @@ class LCChatMemoryComponent(Component): ) ] - def _validate_outputs(self): + def _validate_outputs(self) -> None: required_output_methods = ["build_message_history"] output_names = [output.name for output in self.outputs] for method_name in required_output_methods: diff --git a/src/backend/base/langflow/base/models/model.py b/src/backend/base/langflow/base/models/model.py index 8d0b9eb4be1d..f1ed9885def8 100644 --- a/src/backend/base/langflow/base/models/model.py +++ b/src/backend/base/langflow/base/models/model.py @@ -42,7 +42,7 @@ class LCModelComponent(Component): def _get_exception_message(self, e: Exception): return str(e) - def _validate_outputs(self): + def _validate_outputs(self) -> None: # At least these two outputs must be defined required_output_methods = ["text_response", "build_model"] output_names = [output.name for output in self.outputs] diff --git a/src/backend/base/langflow/base/prompts/api_utils.py b/src/backend/base/langflow/base/prompts/api_utils.py index 71eb52031440..011135b3415b 100644 --- a/src/backend/base/langflow/base/prompts/api_utils.py +++ b/src/backend/base/langflow/base/prompts/api_utils.py @@ -86,7 +86,7 @@ def _check_variable(var, invalid_chars, wrong_variables, empty_variables): return wrong_variables, empty_variables -def _check_for_errors(input_variables, fixed_variables, wrong_variables, empty_variables): +def _check_for_errors(input_variables, fixed_variables, wrong_variables, empty_variables) -> None: if any(var for var in input_variables if var not in fixed_variables): error_message = ( f"Error: Input variables contain invalid characters or formats. \n" @@ -159,7 +159,7 @@ def get_old_custom_fields(custom_fields, name): return old_custom_fields -def add_new_variables_to_template(input_variables, custom_fields, template, name): +def add_new_variables_to_template(input_variables, custom_fields, template, name) -> None: for variable in input_variables: try: template_field = DefaultPromptField(name=variable, display_name=variable) @@ -177,7 +177,7 @@ def add_new_variables_to_template(input_variables, custom_fields, template, name raise HTTPException(status_code=500, detail=str(exc)) from exc -def remove_old_variables_from_template(old_custom_fields, input_variables, custom_fields, template, name): +def remove_old_variables_from_template(old_custom_fields, input_variables, custom_fields, template, name) -> None: for variable in old_custom_fields: if variable not in input_variables: try: @@ -192,7 +192,7 @@ def remove_old_variables_from_template(old_custom_fields, input_variables, custo raise HTTPException(status_code=500, detail=str(exc)) from exc -def update_input_variables_field(input_variables, template): +def update_input_variables_field(input_variables, template) -> None: if "input_variables" in template: template["input_variables"]["value"] = input_variables diff --git a/src/backend/base/langflow/base/textsplitters/model.py b/src/backend/base/langflow/base/textsplitters/model.py index 6574bf946fbb..40d3b928136f 100644 --- a/src/backend/base/langflow/base/textsplitters/model.py +++ b/src/backend/base/langflow/base/textsplitters/model.py @@ -9,7 +9,7 @@ class LCTextSplitterComponent(LCDocumentTransformerComponent): trace_type = "text_splitter" - def _validate_outputs(self): + def _validate_outputs(self) -> None: required_output_methods = ["text_splitter"] output_names = [output.name for output in self.outputs] for method_name in required_output_methods: diff --git a/src/backend/base/langflow/base/tools/component_tool.py b/src/backend/base/langflow/base/tools/component_tool.py index a21452da7e2a..0cc08b54d10c 100644 --- a/src/backend/base/langflow/base/tools/component_tool.py +++ b/src/backend/base/langflow/base/tools/component_tool.py @@ -27,7 +27,7 @@ def _get_input_type(_input: InputTypes): return _input.field_type -def build_description(component: Component, output: Output): +def build_description(component: Component, output: Output) -> str: if not output.required_inputs: logger.warning(f"Output {output.name} does not have required inputs defined") diff --git a/src/backend/base/langflow/base/vectorstores/model.py b/src/backend/base/langflow/base/vectorstores/model.py index d8d01967366b..3750ddb310aa 100644 --- a/src/backend/base/langflow/base/vectorstores/model.py +++ b/src/backend/base/langflow/base/vectorstores/model.py @@ -71,7 +71,7 @@ def __init_subclass__(cls, **kwargs): ), ] - def _validate_outputs(self): + def _validate_outputs(self) -> None: # At least these three outputs must be defined required_output_methods = [ "build_base_retriever", diff --git a/src/backend/base/langflow/components/astra_assistants/astra_assistant_manager.py b/src/backend/base/langflow/components/astra_assistants/astra_assistant_manager.py index 438ecb67dfca..3d4b8ad3afad 100644 --- a/src/backend/base/langflow/components/astra_assistants/astra_assistant_manager.py +++ b/src/backend/base/langflow/components/astra_assistants/astra_assistant_manager.py @@ -67,39 +67,39 @@ class AstraAssistantManager(ComponentWithCache): Output(display_name="Assistant Id", name="output_assistant_id", method="get_assistant_id"), ] - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.lock = asyncio.Lock() - self.initialized = False - self.assistant_response = None - self.tool_output = None - self.thread_id = None - self.assistant_id = None + self.initialized: bool = False + self._assistant_response: Message = None # type: ignore[assignment] + self._tool_output: Message = None # type: ignore[assignment] + self._thread_id: Message = None # type: ignore[assignment] + self._assistant_id: Message = None # type: ignore[assignment] self.client = get_patched_openai_client(self._shared_component_cache) async def get_assistant_response(self) -> Message: await self.initialize() - return self.assistant_response + return self._assistant_response async def get_tool_output(self) -> Message: await self.initialize() - return self.tool_output + return self._tool_output async def get_thread_id(self) -> Message: await self.initialize() - return self.thread_id + return self._thread_id async def get_assistant_id(self) -> Message: await self.initialize() - return self.assistant_id + return self._assistant_id - async def initialize(self): + async def initialize(self) -> None: async with self.lock: if not self.initialized: await self.process_inputs() self.initialized = True - async def process_inputs(self): + async def process_inputs(self) -> None: logger.info(f"env_set is {self.env_set}") logger.info(self.tool) tools = [] @@ -126,10 +126,10 @@ async def process_inputs(self): content = self.user_message result = await assistant_manager.run_thread(content=content, tool=tool_obj) - self.assistant_response = Message(text=result["text"]) + self._assistant_response = Message(text=result["text"]) if "decision" in result: - self.tool_output = Message(text=str(result["decision"].is_complete)) + self._tool_output = Message(text=str(result["decision"].is_complete)) else: - self.tool_output = Message(text=result["text"]) - self.thread_id = Message(text=assistant_manager.thread.id) - self.assistant_id = Message(text=assistant_manager.assistant.id) + self._tool_output = Message(text=result["text"]) + self._thread_id = Message(text=assistant_manager.thread.id) + self._assistant_id = Message(text=assistant_manager.assistant.id) diff --git a/src/backend/base/langflow/components/astra_assistants/create_assistant.py b/src/backend/base/langflow/components/astra_assistants/create_assistant.py index 1bb48e01ef63..b2d354efd287 100644 --- a/src/backend/base/langflow/components/astra_assistants/create_assistant.py +++ b/src/backend/base/langflow/components/astra_assistants/create_assistant.py @@ -45,7 +45,7 @@ class AssistantsCreateAssistant(ComponentWithCache): Output(display_name="Assistant ID", name="assistant_id", method="process_inputs"), ] - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.client = get_patched_openai_client(self._shared_component_cache) diff --git a/src/backend/base/langflow/components/astra_assistants/create_thread.py b/src/backend/base/langflow/components/astra_assistants/create_thread.py index 7aa40e623767..d7848bd2c977 100644 --- a/src/backend/base/langflow/components/astra_assistants/create_thread.py +++ b/src/backend/base/langflow/components/astra_assistants/create_thread.py @@ -21,7 +21,7 @@ class AssistantsCreateThread(ComponentWithCache): Output(display_name="Thread ID", name="thread_id", method="process_inputs"), ] - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.client = get_patched_openai_client(self._shared_component_cache) diff --git a/src/backend/base/langflow/components/astra_assistants/get_assistant.py b/src/backend/base/langflow/components/astra_assistants/get_assistant.py index 1db2afd510fa..61810ad90a47 100644 --- a/src/backend/base/langflow/components/astra_assistants/get_assistant.py +++ b/src/backend/base/langflow/components/astra_assistants/get_assistant.py @@ -26,7 +26,7 @@ class AssistantsGetAssistantName(ComponentWithCache): Output(display_name="Assistant Name", name="assistant_name", method="process_inputs"), ] - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.client = get_patched_openai_client(self._shared_component_cache) diff --git a/src/backend/base/langflow/components/astra_assistants/list_assistants.py b/src/backend/base/langflow/components/astra_assistants/list_assistants.py index 762e16c3c25f..ec6fd058d2f2 100644 --- a/src/backend/base/langflow/components/astra_assistants/list_assistants.py +++ b/src/backend/base/langflow/components/astra_assistants/list_assistants.py @@ -12,7 +12,7 @@ class AssistantsListAssistants(ComponentWithCache): Output(display_name="Assistants", name="assistants", method="process_inputs"), ] - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.client = get_patched_openai_client(self._shared_component_cache) diff --git a/src/backend/base/langflow/components/astra_assistants/run.py b/src/backend/base/langflow/components/astra_assistants/run.py index dbd6422f9199..d7526a4fb03f 100644 --- a/src/backend/base/langflow/components/astra_assistants/run.py +++ b/src/backend/base/langflow/components/astra_assistants/run.py @@ -16,7 +16,7 @@ class AssistantsRun(ComponentWithCache): display_name = "Run Assistant" description = "Executes an Assistant Run against a thread" - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self.client = get_patched_openai_client(self._shared_component_cache) self.thread_id = None @@ -26,7 +26,7 @@ def update_build_config( build_config: dotdict, field_value: Any, field_name: str | None = None, - ): + ) -> None: if field_name == "thread_id": if field_value is None: thread = self.client.beta.threads.create() @@ -75,7 +75,7 @@ def process_inputs(self) -> Message: self.client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=self.user_message) class EventHandler(AssistantEventHandler): - def __init__(self): + def __init__(self) -> None: super().__init__() def on_exception(self, exception: Exception) -> None: diff --git a/src/backend/base/langflow/components/data/GoogleDriveSearch.py b/src/backend/base/langflow/components/data/GoogleDriveSearch.py index ecc91c95ebcf..824c68cdda35 100644 --- a/src/backend/base/langflow/components/data/GoogleDriveSearch.py +++ b/src/backend/base/langflow/components/data/GoogleDriveSearch.py @@ -88,7 +88,7 @@ def generate_query_string(self) -> str: return query - def on_inputs_changed(self): + def on_inputs_changed(self) -> None: # Automatically regenerate the query string when inputs change self.generate_query_string() diff --git a/src/backend/base/langflow/components/embeddings/GoogleGenerativeAIEmbeddings.py b/src/backend/base/langflow/components/embeddings/GoogleGenerativeAIEmbeddings.py index 10f9f9096db5..edd2d63dd223 100644 --- a/src/backend/base/langflow/components/embeddings/GoogleGenerativeAIEmbeddings.py +++ b/src/backend/base/langflow/components/embeddings/GoogleGenerativeAIEmbeddings.py @@ -36,7 +36,7 @@ def build_embeddings(self) -> Embeddings: raise ValueError(msg) class HotaGoogleGenerativeAIEmbeddings(GoogleGenerativeAIEmbeddings): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super(GoogleGenerativeAIEmbeddings, self).__init__(*args, **kwargs) def embed_documents( diff --git a/src/backend/base/langflow/components/prototypes/CreateData.py b/src/backend/base/langflow/components/prototypes/CreateData.py index 140db0b18d16..265c838e1c7c 100644 --- a/src/backend/base/langflow/components/prototypes/CreateData.py +++ b/src/backend/base/langflow/components/prototypes/CreateData.py @@ -99,7 +99,7 @@ def get_data(self): data.update(_value_dict) return data - def validate_text_key(self): + def validate_text_key(self) -> None: """This function validates that the Text Key is one of the keys in the Data.""" data_keys = self.get_data().keys() if self.text_key not in data_keys and self.text_key != "": diff --git a/src/backend/base/langflow/components/prototypes/UpdateData.py b/src/backend/base/langflow/components/prototypes/UpdateData.py index 7f3962ac98f9..2f91823e3c16 100644 --- a/src/backend/base/langflow/components/prototypes/UpdateData.py +++ b/src/backend/base/langflow/components/prototypes/UpdateData.py @@ -106,7 +106,7 @@ def get_data(self): data.update(_value_dict) return data - def validate_text_key(self, data: Data): + def validate_text_key(self, data: Data) -> None: """This function validates that the Text Key is one of the keys in the Data.""" data_keys = data.data.keys() if self.text_key not in data_keys and self.text_key != "": diff --git a/src/backend/base/langflow/components/vectorstores/AstraDB.py b/src/backend/base/langflow/components/vectorstores/AstraDB.py index c7d559b654d9..34687a838b28 100644 --- a/src/backend/base/langflow/components/vectorstores/AstraDB.py +++ b/src/backend/base/langflow/components/vectorstores/AstraDB.py @@ -434,7 +434,7 @@ def build_vector_store(self, vectorize_options=None): return vector_store - def _add_documents_to_vector_store(self, vector_store): + def _add_documents_to_vector_store(self, vector_store) -> None: documents = [] for _input in self.ingest_data or []: if isinstance(_input, Data): @@ -453,7 +453,7 @@ def _add_documents_to_vector_store(self, vector_store): else: logger.debug("No documents to add to the Vector Store.") - def _map_search_type(self): + def _map_search_type(self) -> str: if self.search_type == "Similarity with score threshold": return "similarity_score_threshold" if self.search_type == "MMR (Max Marginal Relevance)": diff --git a/src/backend/base/langflow/components/vectorstores/Cassandra.py b/src/backend/base/langflow/components/vectorstores/Cassandra.py index 870ebcfeb255..b2db432e91bb 100644 --- a/src/backend/base/langflow/components/vectorstores/Cassandra.py +++ b/src/backend/base/langflow/components/vectorstores/Cassandra.py @@ -206,7 +206,7 @@ def build_vector_store(self) -> Cassandra: ) return table - def _map_search_type(self): + def _map_search_type(self) -> str: if self.search_type == "Similarity with score threshold": return "similarity_score_threshold" if self.search_type == "MMR (Max Marginal Relevance)": diff --git a/src/backend/base/langflow/components/vectorstores/CassandraGraph.py b/src/backend/base/langflow/components/vectorstores/CassandraGraph.py index 74bdd50b2cf8..740ed003e18f 100644 --- a/src/backend/base/langflow/components/vectorstores/CassandraGraph.py +++ b/src/backend/base/langflow/components/vectorstores/CassandraGraph.py @@ -181,7 +181,7 @@ def build_vector_store(self) -> CassandraGraphVectorStore: ) return store - def _map_search_type(self): + def _map_search_type(self) -> str: if self.search_type == "Similarity": return "similarity" if self.search_type == "Similarity with score threshold": diff --git a/src/backend/base/langflow/components/vectorstores/HCD.py b/src/backend/base/langflow/components/vectorstores/HCD.py index be0e6eb99482..488a586d9aa4 100644 --- a/src/backend/base/langflow/components/vectorstores/HCD.py +++ b/src/backend/base/langflow/components/vectorstores/HCD.py @@ -253,7 +253,7 @@ def build_vector_store(self): self._add_documents_to_vector_store(vector_store) return vector_store - def _add_documents_to_vector_store(self, vector_store): + def _add_documents_to_vector_store(self, vector_store) -> None: documents = [] for _input in self.ingest_data or []: if isinstance(_input, Data): @@ -272,7 +272,7 @@ def _add_documents_to_vector_store(self, vector_store): else: logger.debug("No documents to add to the Vector Store.") - def _map_search_type(self): + def _map_search_type(self) -> str: if self.search_type == "Similarity with score threshold": return "similarity_score_threshold" if self.search_type == "MMR (Max Marginal Relevance)": diff --git a/src/backend/base/langflow/custom/code_parser/code_parser.py b/src/backend/base/langflow/custom/code_parser/code_parser.py index 092d1fa9bc1e..85ae07b46568 100644 --- a/src/backend/base/langflow/custom/code_parser/code_parser.py +++ b/src/backend/base/langflow/custom/code_parser/code_parser.py @@ -318,7 +318,7 @@ def parse_classes(self, node: ast.ClassDef) -> None: self.process_class_node(_node, class_details) self.data["classes"].append(class_details.model_dump()) - def process_class_node(self, node, class_details): + def process_class_node(self, node, class_details) -> None: for stmt in node.body: if isinstance(stmt, ast.Assign): if attr := self.parse_assign(stmt): diff --git a/src/backend/base/langflow/custom/custom_component/base_component.py b/src/backend/base/langflow/custom/custom_component/base_component.py index cdb2e5e948d5..5d1d63117b99 100644 --- a/src/backend/base/langflow/custom/custom_component/base_component.py +++ b/src/backend/base/langflow/custom/custom_component/base_component.py @@ -31,15 +31,15 @@ class BaseComponent: _user_id: str | UUID | None = None _template_config: dict = {} - def __init__(self, **data): - self.cache = TTLCache(maxsize=1024, ttl=60) + def __init__(self, **data) -> None: + self.cache: TTLCache = TTLCache(maxsize=1024, ttl=60) for key, value in data.items(): if key == "user_id": self._user_id = value else: setattr(self, key, value) - def __setattr__(self, key, value): + def __setattr__(self, key, value) -> None: if key == "_user_id" and self._user_id is not None: logger.warning("user_id is immutable and cannot be changed.") super().__setattr__(key, value) diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index 7f0c675e4887..b816f6af0fd2 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -61,7 +61,7 @@ class Component(CustomComponent): _current_output: str = "" _metadata: dict = {} - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: # if key starts with _ it is a config # else it is an input inputs = {} @@ -107,10 +107,10 @@ def __init__(self, **kwargs): self.set_class_code() self._set_output_required_inputs() - def set_event_manager(self, event_manager: EventManager | None = None): + def set_event_manager(self, event_manager: EventManager | None = None) -> None: self._event_manager = event_manager - def _reset_all_output_values(self): + def _reset_all_output_values(self) -> None: if isinstance(self._outputs_map, dict): for output in self._outputs_map.values(): output.value = UNDEFINED @@ -153,7 +153,7 @@ def __deepcopy__(self, memo): memo[id(self)] = new_component return new_component - def set_class_code(self): + def set_class_code(self) -> None: # Get the source code of the calling class if self._code: return @@ -200,7 +200,7 @@ async def run(self): """ return await self._run() - def set_vertex(self, vertex: Vertex): + def set_vertex(self, vertex: Vertex) -> None: """Sets the vertex for the component. Args: @@ -245,7 +245,7 @@ def get_output(self, name: str) -> Any: msg = f"Output {name} not found in {self.__class__.__name__}" raise ValueError(msg) - def set_on_output(self, name: str, **kwargs): + def set_on_output(self, name: str, **kwargs) -> None: output = self.get_output(name) for key, value in kwargs.items(): if not hasattr(output, key): @@ -253,14 +253,14 @@ def set_on_output(self, name: str, **kwargs): raise ValueError(msg) setattr(output, key, value) - def set_output_value(self, name: str, value: Any): + def set_output_value(self, name: str, value: Any) -> None: if name in self._outputs_map: self._outputs_map[name].value = value else: msg = f"Output {name} not found in {self.__class__.__name__}" raise ValueError(msg) - def map_outputs(self, outputs: list[Output]): + def map_outputs(self, outputs: list[Output]) -> None: """Maps the given list of outputs to the component. Args: @@ -280,7 +280,7 @@ def map_outputs(self, outputs: list[Output]): # allows each instance of each component to modify its own output self._outputs_map[output.name] = deepcopy(output) - def map_inputs(self, inputs: list[InputTypes]): + def map_inputs(self, inputs: list[InputTypes]) -> None: """Maps the given inputs to the component. Args: @@ -296,7 +296,7 @@ def map_inputs(self, inputs: list[InputTypes]): raise ValueError(msg) self._inputs[input_.name] = deepcopy(input_) - def validate(self, params: dict): + def validate(self, params: dict) -> None: """Validates the component parameters. Args: @@ -309,13 +309,16 @@ def validate(self, params: dict): self._validate_inputs(params) self._validate_outputs() - def _set_output_types(self): + def _set_output_types(self) -> None: for output in self._outputs_map.values(): + if output.method is None: + msg = f"Output {output.name} does not have a method" + raise ValueError(msg) return_types = self._get_method_return_type(output.method) output.add_types(return_types) output.set_selected() - def _set_output_required_inputs(self): + def _set_output_required_inputs(self) -> None: for output in self.outputs: if not output.method: continue @@ -326,8 +329,7 @@ def _set_output_required_inputs(self): source_code = inspect.getsource(method) ast_tree = ast.parse(dedent(source_code)) except Exception: # noqa: BLE001 - source_code = self._code - ast_tree = ast.parse(dedent(source_code)) + ast_tree = ast.parse(dedent(self._code or "")) visitor = RequiredInputsVisitor(self._inputs) visitor.visit(ast_tree) @@ -420,7 +422,7 @@ def _find_matching_output_method(self, input_name: str, value: Component): raise TypeError(msg) return getattr(value, output.method) - def _process_connection_or_parameter(self, key, value): + def _process_connection_or_parameter(self, key, value) -> None: _input = self._get_or_create_input(key) # We need to check if callable AND if it is a method from a class that inherits from Component if isinstance(value, Component): @@ -438,7 +440,7 @@ def _process_connection_or_parameter(self, key, value): else: self._set_parameter_or_attribute(key, value) - def _process_connection_or_parameters(self, key, value): + def _process_connection_or_parameters(self, key, value) -> None: # if value is a list of components, we need to process each component if isinstance(value, list): for val in value: @@ -455,13 +457,13 @@ def _get_or_create_input(self, key): self.inputs.append(_input) return _input - def _connect_to_component(self, key, value, _input): + def _connect_to_component(self, key, value, _input) -> None: component = value.__self__ self._components.append(component) output = component.get_output_by_method(value) self._add_edge(component, key, output, _input) - def _add_edge(self, component, key, output, _input): + def _add_edge(self, component, key, output, _input) -> None: self._edges.append( { "source": component._id, @@ -483,7 +485,7 @@ def _add_edge(self, component, key, output, _input): } ) - def _set_parameter_or_attribute(self, key, value): + def _set_parameter_or_attribute(self, key, value) -> None: if isinstance(value, Component): methods = ", ".join([f"'{output.method}'" for output in value.outputs]) msg = ( @@ -527,7 +529,7 @@ def __getattr__(self, name: str) -> Any: msg = f"{name} not found in {self.__class__.__name__}" raise AttributeError(msg) - def _set_input_value(self, name: str, value: Any): + def _set_input_value(self, name: str, value: Any) -> None: if name in self._inputs: input_value = self._inputs[name].value if isinstance(input_value, Component): @@ -547,15 +549,15 @@ def _set_input_value(self, name: str, value: Any): msg = f"Input {name} not found in {self.__class__.__name__}" raise ValueError(msg) - def _validate_outputs(self): + def _validate_outputs(self) -> None: # Raise Error if some rule isn't met pass - def _map_parameters_on_frontend_node(self, frontend_node: ComponentFrontendNode): + def _map_parameters_on_frontend_node(self, frontend_node: ComponentFrontendNode) -> None: for name, value in self._parameters.items(): frontend_node.set_field_value_in_template(name, value) - def _map_parameters_on_template(self, template: dict): + def _map_parameters_on_template(self, template: dict) -> None: for name, value in self._parameters.items(): try: template[name]["value"] = value @@ -625,7 +627,7 @@ def to_frontend_node(self): "id": self._id, } - def _validate_inputs(self, params: dict): + def _validate_inputs(self, params: dict) -> None: # Params keys are the `name` attribute of the Input objects for key, value in params.copy().items(): if key not in self._inputs: @@ -636,7 +638,7 @@ def _validate_inputs(self, params: dict): input_.value = value params[input_.name] = input_.value - def set_attributes(self, params: dict): + def set_attributes(self, params: dict) -> None: self._validate_inputs(params) _attributes = {} for key, value in params.items(): @@ -652,7 +654,7 @@ def set_attributes(self, params: dict): _attributes[key] = input_obj.value or None self._attributes = _attributes - def _set_outputs(self, outputs: list[dict]): + def _set_outputs(self, outputs: list[dict]) -> None: self.outputs = [Output(**output) for output in outputs] for output in self.outputs: setattr(self, output.name, output) @@ -794,7 +796,7 @@ def _get_field_order(self): except KeyError: return [] - def build(self, **kwargs): + def build(self, **kwargs) -> None: self.set_attributes(kwargs) def _get_fallback_input(self, **kwargs): @@ -809,7 +811,7 @@ def get_project_name(self): return self._tracing_service.project_name return "Langflow" - def log(self, message: LoggableType | list[LoggableType], name: str | None = None): + def log(self, message: LoggableType | list[LoggableType], name: str | None = None) -> None: """Logs a message. Args: @@ -828,6 +830,6 @@ def log(self, message: LoggableType | list[LoggableType], name: str | None = Non data["component_id"] = self._id self._event_manager.on_log(data=data) - def _append_tool_output(self): + def _append_tool_output(self) -> None: if next((output for output in self.outputs if output.name == TOOL_OUTPUT_NAME), None) is None: self.outputs.append(Output(name=TOOL_OUTPUT_NAME, display_name="Tool", method="to_toolkit", types=["Tool"])) diff --git a/src/backend/base/langflow/custom/custom_component/component_with_cache.py b/src/backend/base/langflow/custom/custom_component/component_with_cache.py index 90f440799dad..e8a6888d9039 100644 --- a/src/backend/base/langflow/custom/custom_component/component_with_cache.py +++ b/src/backend/base/langflow/custom/custom_component/component_with_cache.py @@ -3,6 +3,6 @@ class ComponentWithCache(Component): - def __init__(self, **data): + def __init__(self, **data) -> None: super().__init__(**data) self._shared_component_cache = get_shared_component_cache_service() diff --git a/src/backend/base/langflow/custom/custom_component/custom_component.py b/src/backend/base/langflow/custom/custom_component/custom_component.py index 6b47a2937d32..7c5a6eec496d 100644 --- a/src/backend/base/langflow/custom/custom_component/custom_component.py +++ b/src/backend/base/langflow/custom/custom_component/custom_component.py @@ -85,7 +85,7 @@ class CustomComponent(BaseComponent): _tracing_service: TracingService | None = None _tree: dict | None = None - def __init__(self, **data): + def __init__(self, **data) -> None: """Initializes a new instance of the CustomComponent class. Args: @@ -93,22 +93,25 @@ def __init__(self, **data): """ self.cache = TTLCache(maxsize=1024, ttl=60) self._logs = [] - self._results = {} - self._artifacts = {} + self._results: dict = {} + self._artifacts: dict = {} super().__init__(**data) - def set_attributes(self, parameters: dict): + def set_attributes(self, parameters: dict) -> None: pass - def set_parameters(self, parameters: dict): + def set_parameters(self, parameters: dict) -> None: self._parameters = parameters self.set_attributes(self._parameters) @property - def trace_name(self): + def trace_name(self) -> str: + if self._vertex is None: + msg = "Vertex is not set" + raise ValueError(msg) return f"{self.display_name} ({self._vertex.id})" - def update_state(self, name: str, value: Any): + def update_state(self, name: str, value: Any) -> None: if not self._vertex: msg = "Vertex is not set" raise ValueError(msg) @@ -118,7 +121,7 @@ def update_state(self, name: str, value: Any): msg = f"Error updating state: {e}" raise ValueError(msg) from e - def stop(self, output_name: str | None = None): + def stop(self, output_name: str | None = None) -> None: if not output_name and self._vertex and len(self._vertex.outputs) == 1: output_name = self._vertex.outputs[0]["name"] elif not output_name: @@ -133,7 +136,7 @@ def stop(self, output_name: str | None = None): msg = f"Error stopping {self.display_name}: {e}" raise ValueError(msg) from e - def append_state(self, name: str, value: Any): + def append_state(self, name: str, value: Any) -> None: if not self._vertex: msg = "Vertex is not set" raise ValueError(msg) diff --git a/src/backend/base/langflow/custom/directory_reader/directory_reader.py b/src/backend/base/langflow/custom/directory_reader/directory_reader.py index 452ceb2cec94..70b453b52d3b 100644 --- a/src/backend/base/langflow/custom/directory_reader/directory_reader.py +++ b/src/backend/base/langflow/custom/directory_reader/directory_reader.py @@ -13,7 +13,7 @@ class CustomComponentPathValueError(ValueError): class StringCompressor: - def __init__(self, input_string): + def __init__(self, input_string) -> None: """Initialize StringCompressor with a string to compress.""" self.input_string = input_string @@ -39,7 +39,7 @@ class DirectoryReader: # the custom components from this directory. base_path = "" - def __init__(self, directory_path, *, compress_code_field=False): + def __init__(self, directory_path, *, compress_code_field=False) -> None: """Initialize DirectoryReader with a directory path and a flag indicating whether to compress the code.""" self.directory_path = directory_path self.compress_code_field = compress_code_field @@ -76,7 +76,7 @@ def filter_loaded_components(self, data: dict, *, with_errors: bool) -> dict: logger.debug(f'Filtered components {"with errors" if with_errors else ""}: {len(filtered)}') return {"menu": filtered} - def validate_code(self, file_content): + def validate_code(self, file_content) -> bool: """Validate the Python code by trying to parse it with ast.parse.""" try: ast.parse(file_content) diff --git a/src/backend/base/langflow/custom/directory_reader/utils.py b/src/backend/base/langflow/custom/directory_reader/utils.py index 24e3e551fe3e..d982252a540d 100644 --- a/src/backend/base/langflow/custom/directory_reader/utils.py +++ b/src/backend/base/langflow/custom/directory_reader/utils.py @@ -109,7 +109,7 @@ def create_invalid_component_template(component, component_name): return component_frontend_node.model_dump(by_alias=True, exclude_none=True) -def log_invalid_component_details(component): +def log_invalid_component_details(component) -> None: """Log details of an invalid component.""" logger.debug(component) logger.debug(f"Component Path: {component.get('path', None)}") diff --git a/src/backend/base/langflow/custom/schema.py b/src/backend/base/langflow/custom/schema.py index c60285d6b863..5c90356b38cd 100644 --- a/src/backend/base/langflow/custom/schema.py +++ b/src/backend/base/langflow/custom/schema.py @@ -28,5 +28,5 @@ class CallableCodeDetails(BaseModel): class MissingDefault: """A class to represent a missing default value.""" - def __repr__(self): + def __repr__(self) -> str: return "MISSING" diff --git a/src/backend/base/langflow/custom/tree_visitor.py b/src/backend/base/langflow/custom/tree_visitor.py index 8bd14526fd94..1e863eead09c 100644 --- a/src/backend/base/langflow/custom/tree_visitor.py +++ b/src/backend/base/langflow/custom/tree_visitor.py @@ -1,15 +1,16 @@ import ast +from typing import Any from typing_extensions import override class RequiredInputsVisitor(ast.NodeVisitor): - def __init__(self, inputs): - self.inputs = inputs - self.required_inputs = set() + def __init__(self, inputs: dict[str, Any]): + self.inputs: dict[str, Any] = inputs + self.required_inputs: set[str] = set() @override - def visit_Attribute(self, node): + def visit_Attribute(self, node) -> None: if isinstance(node.value, ast.Name) and node.value.id == "self" and node.attr in self.inputs: self.required_inputs.add(node.attr) self.generic_visit(node) diff --git a/src/backend/base/langflow/custom/utils.py b/src/backend/base/langflow/custom/utils.py index 5d01f781e619..23468defae33 100644 --- a/src/backend/base/langflow/custom/utils.py +++ b/src/backend/base/langflow/custom/utils.py @@ -32,7 +32,7 @@ class UpdateBuildConfigError(Exception): pass -def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: list[str]): +def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: list[str]) -> None: """Add output types to the frontend node.""" for return_type in return_types: if return_type is None: @@ -55,7 +55,7 @@ def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: l frontend_node.add_output_type(_return_type) -def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: list[str]): +def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: list[str]) -> None: """Reorder fields in the frontend node based on the specified field_order.""" if not field_order: return @@ -69,7 +69,7 @@ def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: list frontend_node.field_order = field_order -def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: list[str]): +def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: list[str]) -> None: """Add base classes to the frontend node.""" for return_type_instance in return_types: if return_type_instance is None: @@ -196,7 +196,7 @@ def add_new_custom_field( return frontend_node -def add_extra_fields(frontend_node, field_config, function_args): +def add_extra_fields(frontend_node, field_config, function_args) -> None: """Add extra fields to the frontend node.""" if not function_args: return diff --git a/src/backend/base/langflow/events/event_manager.py b/src/backend/base/langflow/events/event_manager.py index 6390ca45dfda..4567e7bceafa 100644 --- a/src/backend/base/langflow/events/event_manager.py +++ b/src/backend/base/langflow/events/event_manager.py @@ -25,7 +25,7 @@ def __init__(self, queue: asyncio.Queue): self.events: dict[str, PartialEventCallback] = {} @staticmethod - def _validate_callback(callback: EventCallback): + def _validate_callback(callback: EventCallback) -> None: if not callable(callback): msg = "Callback must be callable" raise TypeError(msg) @@ -39,7 +39,7 @@ def _validate_callback(callback: EventCallback): msg = "Callback must have exactly 3 parameters: manager, event_type, and data" raise ValueError(msg) - def register_event(self, name: str, event_type: str, callback: EventCallback | None = None): + def register_event(self, name: str, event_type: str, callback: EventCallback | None = None) -> None: if not name: msg = "Event name cannot be empty" raise ValueError(msg) @@ -52,14 +52,14 @@ def register_event(self, name: str, event_type: str, callback: EventCallback | N _callback = partial(callback, manager=self, event_type=event_type) self.events[name] = _callback - def send_event(self, *, event_type: str, data: LoggableType): + def send_event(self, *, event_type: str, data: LoggableType) -> None: jsonable_data = jsonable_encoder(data) json_data = {"event": event_type, "data": jsonable_data} event_id = uuid.uuid4() str_data = json.dumps(json_data) + "\n\n" self.queue.put_nowait((event_id, str_data.encode("utf-8"), time.time())) - def noop(self, *, data: LoggableType): + def noop(self, *, data: LoggableType) -> None: pass def __getattr__(self, name: str) -> PartialEventCallback: diff --git a/src/backend/base/langflow/graph/graph/ascii.py b/src/backend/base/langflow/graph/graph/ascii.py index 2bee1f1197d2..46cfa245b6f2 100644 --- a/src/backend/base/langflow/graph/graph/ascii.py +++ b/src/backend/base/langflow/graph/graph/ascii.py @@ -24,7 +24,7 @@ class VertexViewer: HEIGHT = 3 # top and bottom box edges + text - def __init__(self, name): + def __init__(self, name) -> None: self._h = self.HEIGHT # top and bottom box edges + text self._w = len(name) + 2 # right and left bottom edges + text @@ -40,7 +40,7 @@ def w(self): class AsciiCanvas: """Class for drawing in ASCII.""" - def __init__(self, cols, lines): + def __init__(self, cols, lines) -> None: assert cols > 1 assert lines > 1 self.cols = cols @@ -53,19 +53,19 @@ def get_lines(self): def draws(self): return "\n".join(self.get_lines()) - def draw(self): + def draw(self) -> None: """Draws ASCII canvas on the screen.""" lines = self.get_lines() print("\n".join(lines)) # noqa: T201 - def point(self, x, y, char): + def point(self, x, y, char) -> None: """Create a point on ASCII canvas.""" assert len(char) == 1 assert 0 <= x < self.cols assert 0 <= y < self.lines self.canvas[y][x] = char - def line(self, x0, y0, x1, y1, char): + def line(self, x0, y0, x1, y1, char) -> None: """Create a line on ASCII canvas.""" if x0 > x1: x1, x0 = x0, x1 @@ -85,12 +85,12 @@ def line(self, x0, y0, x1, y1, char): x = x0 + int(round((y - y0) * dx / float(dy))) if dy else x0 self.point(x, y, char) - def text(self, x, y, text): + def text(self, x, y, text) -> None: """Print a text on ASCII canvas.""" for i, char in enumerate(text): self.point(x + i, y, char) - def box(self, x0, y0, width, height): + def box(self, x0, y0, width, height) -> None: """Create a box on ASCII canvas.""" assert width > 1 assert height > 1 diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index 717b8b025dfb..ca5dd094f5e5 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -201,7 +201,7 @@ def dump( graph_dict["endpoint_name"] = str(endpoint_name) return graph_dict - def add_nodes_and_edges(self, nodes: list[NodeData], edges: list[EdgeData]): + def add_nodes_and_edges(self, nodes: list[NodeData], edges: list[EdgeData]) -> None: self._vertices = nodes self._edges = edges self.raw_graph_data = {"nodes": nodes, "edges": edges} @@ -238,7 +238,7 @@ def add_component(self, component: Component, component_id: str | None = None) - return component_id - def _set_start_and_end(self, start: Component, end: Component): + def _set_start_and_end(self, start: Component, end: Component) -> None: if not hasattr(start, "to_frontend_node"): msg = f"start must be a Component. Got {type(start)}" raise TypeError(msg) @@ -248,7 +248,7 @@ def _set_start_and_end(self, start: Component, end: Component): self.add_component(start, start._id) self.add_component(end, end._id) - def add_component_edge(self, source_id: str, output_input_tuple: tuple[str, str], target_id: str): + def add_component_edge(self, source_id: str, output_input_tuple: tuple[str, str], target_id: str) -> None: source_vertex = self.get_vertex(source_id) if not isinstance(source_vertex, ComponentVertex): msg = f"Source vertex {source_id} is not a component vertex." @@ -337,7 +337,7 @@ def _snapshot(self): "run_manager": copy.deepcopy(self.run_manager.to_dict()), } - def __apply_config(self, config: StartConfigDict): + def __apply_config(self, config: StartConfigDict) -> None: for vertex in self.vertices: if vertex._custom_component is None: continue @@ -373,7 +373,7 @@ def start( except StopAsyncIteration: break - def _add_edge(self, edge: EdgeData): + def _add_edge(self, edge: EdgeData) -> None: self.add_edge(edge) source_id = edge["data"]["sourceHandle"]["id"] target_id = edge["data"]["targetHandle"]["id"] @@ -382,16 +382,16 @@ def _add_edge(self, edge: EdgeData): self.in_degree_map[target_id] += 1 self.parent_child_map[source_id].append(target_id) - def add_node(self, node: NodeData): + def add_node(self, node: NodeData) -> None: self._vertices.append(node) - def add_edge(self, edge: EdgeData): + def add_edge(self, edge: EdgeData) -> None: # Check if the edge already exists if edge in self._edges: return self._edges.append(edge) - def initialize(self): + def initialize(self) -> None: self._build_graph() self.build_graph_maps(self.edges) self.define_vertices_lists() @@ -424,7 +424,7 @@ def update_state(self, name: str, record: str | Data, caller: str | None = None) self.state_manager.update_state(name, record, run_id=self._run_id) - def activate_state_vertices(self, name: str, caller: str): + def activate_state_vertices(self, name: str, caller: str) -> None: """Activates the state vertices in the graph with the given name and caller. Args: @@ -473,7 +473,7 @@ def activate_state_vertices(self, name: str, caller: str): vertices_to_run=self.vertices_to_run, ) - def reset_activated_vertices(self): + def reset_activated_vertices(self) -> None: """Resets the activated vertices in the graph.""" self.activated_vertices = [] @@ -490,7 +490,7 @@ def append_state(self, name: str, record: str | Data, caller: str | None = None) self.state_manager.append_state(name, record, run_id=self._run_id) - def validate_stream(self): + def validate_stream(self) -> None: """Validates the stream configuration of the graph. If there are two vertices in the same graph (connected by edges) @@ -548,7 +548,7 @@ def run_id(self): raise ValueError(msg) return self._run_id - def set_run_id(self, run_id: uuid.UUID | None = None): + def set_run_id(self, run_id: uuid.UUID | None = None) -> None: """Sets the ID of the current run. Args: @@ -564,7 +564,7 @@ def set_run_id(self, run_id: uuid.UUID | None = None): if self.tracing_service: self.tracing_service.set_run_id(run_id) - def set_run_name(self): + def set_run_name(self) -> None: # Given a flow name, flow_id if not self.tracing_service: return @@ -573,16 +573,16 @@ def set_run_name(self): self.set_run_id() self.tracing_service.set_run_name(name) - async def initialize_run(self): + async def initialize_run(self) -> None: if self.tracing_service: await self.tracing_service.initialize_tracers() - def _end_all_traces_async(self, outputs: dict[str, Any] | None = None, error: Exception | None = None): + def _end_all_traces_async(self, outputs: dict[str, Any] | None = None, error: Exception | None = None) -> None: task = asyncio.create_task(self.end_all_traces(outputs, error)) self._end_trace_tasks.add(task) task.add_done_callback(self._end_trace_tasks.discard) - async def end_all_traces(self, outputs: dict[str, Any] | None = None, error: Exception | None = None): + async def end_all_traces(self, outputs: dict[str, Any] | None = None, error: Exception | None = None) -> None: if not self.tracing_service: return self._end_time = datetime.now(timezone.utc) @@ -602,7 +602,7 @@ def sorted_vertices_layers(self) -> list[list[str]]: self.sort_vertices() return self._sorted_vertices_layers - def define_vertices_lists(self): + def define_vertices_lists(self) -> None: """Defines the lists of vertices that are inputs, outputs, and have session_id.""" attributes = ["is_input", "is_output", "has_session_id", "is_state"] for vertex in self.vertices: @@ -610,7 +610,7 @@ def define_vertices_lists(self): if getattr(vertex, attribute): getattr(self, f"_{attribute}_vertices").append(vertex.id) - def _set_inputs(self, input_components: list[str], inputs: dict[str, str], input_type: InputType | None): + def _set_inputs(self, input_components: list[str], inputs: dict[str, str], input_type: InputType | None) -> None: for vertex_id in self._is_input_vertices: vertex = self.get_vertex(vertex_id) # If the vertex is not in the input_components list @@ -838,7 +838,7 @@ def metadata(self): "flow_name": self.flow_name, } - def build_graph_maps(self, edges: list[CycleEdge] | None = None, vertices: list[Vertex] | None = None): + def build_graph_maps(self, edges: list[CycleEdge] | None = None, vertices: list[Vertex] | None = None) -> None: """Builds the adjacency maps for the graph.""" if edges is None: edges = self.edges @@ -851,26 +851,28 @@ def build_graph_maps(self, edges: list[CycleEdge] | None = None, vertices: list[ self.in_degree_map = self.build_in_degree(edges) self.parent_child_map = self.build_parent_child_map(vertices) - def reset_inactivated_vertices(self): + def reset_inactivated_vertices(self) -> None: """Resets the inactivated vertices in the graph.""" for vertex_id in self.inactivated_vertices.copy(): self.mark_vertex(vertex_id, "ACTIVE") - self.inactivated_vertices = [] + self.inactivated_vertices = set() self.inactivated_vertices = set() - def mark_all_vertices(self, state: str): + def mark_all_vertices(self, state: str) -> None: """Marks all vertices in the graph.""" for vertex in self.vertices: vertex.set_state(state) - def mark_vertex(self, vertex_id: str, state: str): + def mark_vertex(self, vertex_id: str, state: str) -> None: """Marks a vertex in the graph.""" vertex = self.get_vertex(vertex_id) vertex.set_state(state) if state == VertexStates.INACTIVE: self.run_manager.remove_from_predecessors(vertex_id) - def _mark_branch(self, vertex_id: str, state: str, visited: set | None = None, output_name: str | None = None): + def _mark_branch( + self, vertex_id: str, state: str, visited: set | None = None, output_name: str | None = None + ) -> None: """Marks a branch of the graph.""" if visited is None: visited = set() @@ -889,7 +891,7 @@ def _mark_branch(self, vertex_id: str, state: str, visited: set | None = None, o continue self._mark_branch(child_id, state, visited) - def mark_branch(self, vertex_id: str, state: str, output_name: str | None = None): + def mark_branch(self, vertex_id: str, state: str, output_name: str | None = None) -> None: self._mark_branch(vertex_id=vertex_id, state=state, output_name=output_name) new_predecessor_map, _ = self.build_adjacency_maps(self.edges) self.run_manager.update_run_state( @@ -910,10 +912,10 @@ def build_parent_child_map(self, vertices: list[Vertex]): parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)] return parent_child_map - def increment_run_count(self): + def increment_run_count(self) -> None: self._runs += 1 - def increment_update_count(self): + def increment_update_count(self) -> None: self._updates += 1 def __getstate__(self): @@ -1239,7 +1241,7 @@ def get_next_in_queue(self): return None return self._run_queue.popleft() - def extend_run_queue(self, vertices: list[str]): + def extend_run_queue(self, vertices: list[str]) -> None: self._run_queue.extend(vertices) async def astep( @@ -1292,7 +1294,7 @@ def get_snapshot(self): } ) - def _record_snapshot(self, vertex_id: str | None = None): + def _record_snapshot(self, vertex_id: str | None = None) -> None: self._snapshots.append(self.get_snapshot()) if vertex_id: self._call_order.append(vertex_id) @@ -1557,7 +1559,7 @@ def topological_sort(self) -> list[Vertex]: state = dict.fromkeys(self.vertices, 0) sorted_vertices = [] - def dfs(vertex): + def dfs(vertex) -> None: if state[vertex] == 1: # We have a cycle msg = "Graph contains a cycle, cannot perform topological sort" @@ -1770,7 +1772,7 @@ def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> list[ children.append(vertex) return children - def __repr__(self): + def __repr__(self) -> str: vertex_ids = [vertex.id for vertex in self.vertices] edges_repr = "\n".join([f" {edge.source_id} --> {edge.target_id}" for edge in self.edges]) @@ -2009,7 +2011,7 @@ def is_vertex_runnable(self, vertex_id: str) -> bool: is_active = self.get_vertex(vertex_id).is_active() return self.run_manager.is_vertex_runnable(vertex_id, is_active=is_active) - def build_run_map(self): + def build_run_map(self) -> None: """Builds the run map for the graph. This method is responsible for building the run map for the graph, @@ -2036,7 +2038,7 @@ def find_runnable_predecessors_for_successor(self, vertex_id: str) -> list[str]: runnable_vertices = [] visited = set() - def find_runnable_predecessors(predecessor: Vertex): + def find_runnable_predecessors(predecessor: Vertex) -> None: predecessor_id = predecessor.id if predecessor_id in visited: return @@ -2052,10 +2054,10 @@ def find_runnable_predecessors(predecessor: Vertex): find_runnable_predecessors(self.get_vertex(predecessor_id)) return runnable_vertices - def remove_from_predecessors(self, vertex_id: str): + def remove_from_predecessors(self, vertex_id: str) -> None: self.run_manager.remove_from_predecessors(vertex_id) - def remove_vertex_from_runnables(self, vertex_id: str): + def remove_vertex_from_runnables(self, vertex_id: str) -> None: self.run_manager.remove_vertex_from_runnables(vertex_id) def get_top_level_vertices(self, vertices_ids): diff --git a/src/backend/base/langflow/graph/graph/constants.py b/src/backend/base/langflow/graph/graph/constants.py index 90395bdefe25..b5a1e411d583 100644 --- a/src/backend/base/langflow/graph/graph/constants.py +++ b/src/backend/base/langflow/graph/graph/constants.py @@ -3,7 +3,7 @@ class Finish: - def __bool__(self): + def __bool__(self) -> bool: return True def __eq__(self, other): @@ -17,7 +17,7 @@ def _import_vertex_types(): class VertexTypesDict(LazyLoadDictBase): - def __init__(self): + def __init__(self) -> None: self._all_types_dict = None self._types = _import_vertex_types diff --git a/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py b/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py index 01431dac15aa..fece73cd6361 100644 --- a/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py +++ b/src/backend/base/langflow/graph/graph/runnable_vertices_manager.py @@ -2,11 +2,11 @@ class RunnableVerticesManager: - def __init__(self): - self.run_map = defaultdict(list) # Tracks successors of each vertex - self.run_predecessors = defaultdict(set) # Tracks predecessors for each vertex - self.vertices_to_run = set() # Set of vertices that are ready to run - self.vertices_being_run = set() # Set of vertices that are currently running + def __init__(self) -> None: + self.run_map: dict[str, list[str]] = defaultdict(list) # Tracks successors of each vertex + self.run_predecessors: dict[str, set[str]] = defaultdict(set) # Tracks predecessors for each vertex + self.vertices_to_run: set[str] = set() # Set of vertices that are ready to run + self.vertices_being_run: set[str] = set() # Set of vertices that are currently running def to_dict(self) -> dict: return { @@ -42,7 +42,7 @@ def __setstate__(self, state: dict) -> None: def all_predecessors_are_fulfilled(self) -> bool: return all(not value for value in self.run_predecessors.values()) - def update_run_state(self, run_predecessors: dict, vertices_to_run: set): + def update_run_state(self, run_predecessors: dict, vertices_to_run: set) -> None: self.run_predecessors.update(run_predecessors) self.vertices_to_run.update(vertices_to_run) self.build_run_map(self.run_predecessors, self.vertices_to_run) @@ -60,14 +60,14 @@ def is_vertex_runnable(self, vertex_id: str, *, is_active: bool) -> bool: def are_all_predecessors_fulfilled(self, vertex_id: str) -> bool: return not any(self.run_predecessors.get(vertex_id, [])) - def remove_from_predecessors(self, vertex_id: str): + def remove_from_predecessors(self, vertex_id: str) -> None: """Removes a vertex from the predecessor list of its successors.""" predecessors = self.run_map.get(vertex_id, []) for predecessor in predecessors: if vertex_id in self.run_predecessors[predecessor]: self.run_predecessors[predecessor].remove(vertex_id) - def build_run_map(self, predecessor_map, vertices_to_run): + def build_run_map(self, predecessor_map, vertices_to_run) -> None: """Builds a map of vertices and their runnable successors.""" self.run_map = defaultdict(list) for vertex_id, predecessors in predecessor_map.items(): @@ -76,16 +76,16 @@ def build_run_map(self, predecessor_map, vertices_to_run): self.run_predecessors = predecessor_map.copy() self.vertices_to_run = vertices_to_run - def update_vertex_run_state(self, vertex_id: str, *, is_runnable: bool): + def update_vertex_run_state(self, vertex_id: str, *, is_runnable: bool) -> None: """Updates the runnable state of a vertex.""" if is_runnable: self.vertices_to_run.add(vertex_id) else: self.vertices_being_run.discard(vertex_id) - def remove_vertex_from_runnables(self, v_id): + def remove_vertex_from_runnables(self, v_id) -> None: self.update_vertex_run_state(v_id, is_runnable=False) self.remove_from_predecessors(v_id) - def add_to_vertices_being_run(self, v_id): + def add_to_vertices_being_run(self, v_id) -> None: self.vertices_being_run.add(v_id) diff --git a/src/backend/base/langflow/graph/graph/state_manager.py b/src/backend/base/langflow/graph/graph/state_manager.py index 035e79290205..38aed4b90740 100644 --- a/src/backend/base/langflow/graph/graph/state_manager.py +++ b/src/backend/base/langflow/graph/graph/state_manager.py @@ -13,7 +13,7 @@ class GraphStateManager: - def __init__(self): + def __init__(self) -> None: try: self.state_service: StateService = get_state_service() except Exception: # noqa: BLE001 @@ -22,26 +22,14 @@ def __init__(self): self.state_service = InMemoryStateService(get_settings_service()) - def append_state(self, key, new_state, run_id: str): + def append_state(self, key, new_state, run_id: str) -> None: self.state_service.append_state(key, new_state, run_id) - def update_state(self, key, new_state, run_id: str): + def update_state(self, key, new_state, run_id: str) -> None: self.state_service.update_state(key, new_state, run_id) def get_state(self, key, run_id: str): return self.state_service.get_state(key, run_id) - def subscribe(self, key, observer: Callable): + def subscribe(self, key, observer: Callable) -> None: self.state_service.subscribe(key, observer) - - def notify_observers(self, key, new_state): - for callback in self.observers[key]: - callback(key, new_state, append=False) - - def notify_append_observers(self, key, new_state): - for callback in self.observers[key]: - try: - callback(key, new_state, append=True) - except Exception: # noqa: BLE001 - logger.exception(f"Error in observer {callback} for key {key}") - logger.warning("Callbacks not implemented yet") diff --git a/src/backend/base/langflow/graph/graph/utils.py b/src/backend/base/langflow/graph/graph/utils.py index 25603e849bdb..02ea8e0b7b09 100644 --- a/src/backend/base/langflow/graph/graph/utils.py +++ b/src/backend/base/langflow/graph/graph/utils.py @@ -27,13 +27,13 @@ def find_last_node(nodes, edges): return next((n for n in nodes if all(e["source"] != n["id"] for e in edges)), None) -def add_parent_node_id(nodes, parent_node_id): +def add_parent_node_id(nodes, parent_node_id) -> None: """This function receives a list of nodes and adds a parent_node_id to each node.""" for node in nodes: node["parent_node_id"] = parent_node_id -def add_frozen(nodes, frozen): +def add_frozen(nodes, frozen) -> None: """This function receives a list of nodes and adds a frozen to each node.""" for node in nodes: node["data"]["node"]["frozen"] = frozen @@ -75,7 +75,7 @@ def process_flow(flow_object): cloned_flow = copy.deepcopy(flow_object) processed_nodes = set() # To keep track of processed nodes - def process_node(node): + def process_node(node) -> None: node_id = node.get("id") # If node already processed, skip @@ -100,7 +100,7 @@ def process_node(node): return cloned_flow -def update_template(template, g_nodes): +def update_template(template, g_nodes) -> None: """Updates the template of a node in a graph with the given template. Args: @@ -149,7 +149,7 @@ def update_target_handle(new_edge, g_nodes): return new_edge -def set_new_target_handle(proxy_id, new_edge, target_handle, node): +def set_new_target_handle(proxy_id, new_edge, target_handle, node) -> None: """Sets a new target handle for a given edge. Args: @@ -330,7 +330,7 @@ def has_cycle(vertex_ids: list[str], edges: list[tuple[str, str]]) -> bool: graph[u].append(v) # Utility function to perform DFS - def dfs(v, visited, rec_stack): + def dfs(v, visited, rec_stack) -> bool: visited.add(v) rec_stack.add(v) diff --git a/src/backend/base/langflow/graph/state/model.py b/src/backend/base/langflow/graph/state/model.py index cbc12a9b0b89..b56e886389ba 100644 --- a/src/backend/base/langflow/graph/state/model.py +++ b/src/backend/base/langflow/graph/state/model.py @@ -126,10 +126,10 @@ def build_output_setter(method: Callable, *, validate: bool = True) -> Callable: >>> print(component.get_output_by_method(component.set_message).value) # Prints "New message" """ - def output_setter(self, value): # noqa: ARG001 + def output_setter(self, value) -> None: # noqa: ARG001 if validate: __validate_method(method) - methods_class = method.__self__ + methods_class = method.__self__ # type: ignore[attr-defined] output = methods_class.get_output_by_method(method) output.value = value diff --git a/src/backend/base/langflow/graph/utils.py b/src/backend/base/langflow/graph/utils.py index 4a466a5dd8e7..5ceee698ba9c 100644 --- a/src/backend/base/langflow/graph/utils.py +++ b/src/backend/base/langflow/graph/utils.py @@ -172,7 +172,7 @@ def log_vertex_build( params: Any, data: ResultDataResponse, artifacts: dict | None = None, -): +) -> None: try: if not get_settings_service().settings.vertex_builds_storage_enabled: return diff --git a/src/backend/base/langflow/graph/vertex/base.py b/src/backend/base/langflow/graph/vertex/base.py index be327c1fe14e..5fbe1d965e04 100644 --- a/src/backend/base/langflow/graph/vertex/base.py +++ b/src/backend/base/langflow/graph/vertex/base.py @@ -75,8 +75,8 @@ def __init__( self.base_type: str | None = base_type self.outputs: list[dict] = [] self._parse_data() - self._built_object = UnbuiltObject() - self._built_result = None + self._built_object: Any = UnbuiltObject() + self._built_result: Any = None self._built = False self._successors_ids: list[str] | None = None self.artifacts: dict[str, Any] = {} @@ -106,7 +106,7 @@ def __init__( self.state = VertexStates.ACTIVE self.log_transaction_tasks: set[asyncio.Task] = set() - def set_input_value(self, name: str, value: Any): + def set_input_value(self, name: str, value: Any) -> None: if self._custom_component is None: msg = f"Vertex {self.id} does not have a component instance." raise ValueError(msg) @@ -115,20 +115,20 @@ def set_input_value(self, name: str, value: Any): def to_data(self): return self._data - def add_component_instance(self, component_instance: Component): + def add_component_instance(self, component_instance: Component) -> None: component_instance.set_vertex(self) self._custom_component = component_instance - def add_result(self, name: str, result: Any): + def add_result(self, name: str, result: Any) -> None: self.results[name] = result - def update_graph_state(self, key, new_state, *, append: bool): + def update_graph_state(self, key, new_state, *, append: bool) -> None: if append: self.graph.append_state(key, new_state, caller=self.id) else: self.graph.update_state(key, new_state, caller=self.id) - def set_state(self, state: str): + def set_state(self, state: str) -> None: self.state = VertexStates[state] if self.state == VertexStates.INACTIVE and self.graph.in_degree_map[self.id] <= 1: # If the vertex is inactive and has only one in degree @@ -144,7 +144,7 @@ def is_active(self): def avg_build_time(self): return sum(self.build_times) / len(self.build_times) if self.build_times else 0 - def add_build_time(self, time): + def add_build_time(self, time) -> None: self.build_times.append(time) def set_result(self, result: ResultData) -> None: @@ -300,7 +300,7 @@ def _set_params_from_normal_edge(self, params: dict, edge: Edge, template_dict: params[param_key] = self.graph.get_vertex(edge.source_id) return params - def _build_params(self): + def _build_params(self) -> None: # sourcery skip: merge-list-append, remove-redundant-if # Some params are required, some are optional # but most importantly, some params are python base classes @@ -326,7 +326,7 @@ def _build_params(self): return template_dict = {key: value for key, value in self.data["node"]["template"].items() if isinstance(value, dict)} - params = {} + params: dict = {} for edge in self.edges: if not hasattr(edge, "target_param"): @@ -438,7 +438,7 @@ def _build_params(self): self.load_from_db_fields = load_from_db_fields self._raw_params = params.copy() - def update_raw_params(self, new_params: Mapping[str, str | list[str]], *, overwrite: bool = False): + def update_raw_params(self, new_params: Mapping[str, str | list[str]], *, overwrite: bool = False) -> None: """Update the raw parameters of the vertex with the given new parameters. Args: @@ -466,7 +466,7 @@ def has_cycle_edges(self): """Checks if the vertex has any cycle edges.""" return self._has_cycle_edges - async def instantiate_component(self, user_id=None): + async def instantiate_component(self, user_id=None) -> None: if not self._custom_component: self._custom_component, _ = await initialize.loading.instantiate_class( user_id=user_id, @@ -478,7 +478,7 @@ async def _build( fallback_to_env_vars, user_id=None, event_manager: EventManager | None = None, - ): + ) -> None: """Initiate the build process.""" logger.debug(f"Building {self.display_name}") await self._build_each_vertex_in_params_dict() @@ -500,6 +500,7 @@ async def _build( custom_component=custom_component, custom_params=custom_params, fallback_to_env_vars=fallback_to_env_vars, + base_type=self.base_type, ) self._validate_built_object() @@ -545,7 +546,7 @@ def extract_messages_from_artifacts(self, artifacts: dict[str, Any]) -> list[dic return messages - def _finalize_build(self): + def _finalize_build(self) -> None: result_dict = self.get_built_result() # We need to set the artifacts to pass information # to the frontend @@ -563,7 +564,7 @@ def _finalize_build(self): ) self.set_result(result_dict) - async def _build_each_vertex_in_params_dict(self): + async def _build_each_vertex_in_params_dict(self) -> None: """Iterates over each vertex in the params dictionary and builds it.""" for key, value in self._raw_params.items(): if self._is_vertex(value): @@ -588,7 +589,7 @@ async def _build_dict_and_update_params( self, key, vertices_dict: dict[str, Vertex], - ): + ) -> None: """Iterates over a dictionary of vertices, builds each and updates the params dictionary.""" for sub_key, value in vertices_dict.items(): if not self._is_vertex(value): @@ -647,7 +648,7 @@ async def _get_result( self._log_transaction_async(str(flow_id), source=self, target=requester, status="success") return result - async def _build_vertex_and_update_params(self, key, vertex: Vertex): + async def _build_vertex_and_update_params(self, key, vertex: Vertex) -> None: """Builds a given vertex and updates the params dictionary accordingly.""" result = await vertex.get_result(self, target_handle_name=key) self._handle_func(key, result) @@ -659,7 +660,7 @@ async def _build_list_of_vertices_and_update_params( self, key, vertices: list[Vertex], - ): + ) -> None: """Iterates over a list of vertices, builds each and updates the params dictionary.""" self.params[key] = [] for vertex in vertices: @@ -685,7 +686,7 @@ async def _build_list_of_vertices_and_update_params( ) raise ValueError(msg) from e - def _handle_func(self, key, result): + def _handle_func(self, key, result) -> None: """Handles 'func' key by checking if the result is a function and setting it as coroutine.""" if key == "func": if not isinstance(result, types.FunctionType): @@ -698,19 +699,21 @@ def _handle_func(self, key, result): else: self.params["coroutine"] = sync_to_async(result) - def _extend_params_list_with_result(self, key, result): + def _extend_params_list_with_result(self, key, result) -> None: """Extends a list in the params dictionary with the given result if it exists.""" if isinstance(self.params[key], list): self.params[key].extend(result) - async def _build_results(self, custom_component, custom_params, *, fallback_to_env_vars=False): + async def _build_results( + self, custom_component, custom_params, base_type: str, *, fallback_to_env_vars=False + ) -> None: try: result = await initialize.loading.get_instance_results( custom_component=custom_component, custom_params=custom_params, vertex=self, fallback_to_env_vars=fallback_to_env_vars, - base_type=self.base_type, + base_type=base_type, ) self.outputs_logs = build_output_logs(self, result) @@ -722,7 +725,7 @@ async def _build_results(self, custom_component, custom_params, *, fallback_to_e msg = f"Error building Component {self.display_name}: \n\n{exc}" raise ComponentBuildError(msg, tb) from exc - def _update_built_object_and_artifacts(self, result: Any | tuple[Any, dict] | tuple[Component, Any, dict]): + def _update_built_object_and_artifacts(self, result: Any | tuple[Any, dict] | tuple[Component, Any, dict]) -> None: """Updates the built object and its artifacts.""" if isinstance(result, tuple): if len(result) == 2: # noqa: PLR2004 @@ -738,7 +741,7 @@ def _update_built_object_and_artifacts(self, result: Any | tuple[Any, dict] | tu else: self._built_object = result - def _validate_built_object(self): + def _validate_built_object(self) -> None: """Checks if the built object is None and raises a ValueError if so.""" if isinstance(self._built_object, UnbuiltObject): msg = f"{self.display_name}: {self._built_object_repr()}" @@ -754,7 +757,7 @@ def _validate_built_object(self): msg = f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead." raise ValueError(msg) - def _reset(self): + def _reset(self) -> None: self._built = False self._built_object = UnbuiltObject() self._built_result = UnbuiltResult() @@ -762,10 +765,10 @@ def _reset(self): self.steps_ran = [] self._build_params() - def _is_chat_input(self): + def _is_chat_input(self) -> bool: return False - def build_inactive(self): + def build_inactive(self) -> None: # Just set the results to None self._built = True self._built_object = None @@ -865,11 +868,11 @@ def __eq__(self, __o: object) -> bool: def __hash__(self) -> int: return id(self) - def _built_object_repr(self): + def _built_object_repr(self) -> str: # Add a message with an emoji, stars for sucess, return "Built successfully ✨" if self._built_object is not None else "Failed to build 😵‍💫" - def apply_on_outputs(self, func: Callable[[Any], Any]): + def apply_on_outputs(self, func: Callable[[Any], Any]) -> None: """Applies a function to the outputs of the vertex.""" if not self._custom_component or not self._custom_component.outputs: return diff --git a/src/backend/base/langflow/graph/vertex/types.py b/src/backend/base/langflow/graph/vertex/types.py index 4decf39b35bc..bb3eb6926a4a 100644 --- a/src/backend/base/langflow/graph/vertex/types.py +++ b/src/backend/base/langflow/graph/vertex/types.py @@ -57,7 +57,7 @@ def _built_object_repr(self): return self.artifacts["repr"] or super()._built_object_repr() return None - def _update_built_object_and_artifacts(self, result): + def _update_built_object_and_artifacts(self, result) -> None: """Updates the built object and its artifacts.""" if isinstance(result, tuple): if len(result) == 2: # noqa: PLR2004 @@ -182,7 +182,7 @@ def extract_messages_from_artifacts(self, artifacts: dict[str, Any]) -> list[dic ) return messages - def _finalize_build(self): + def _finalize_build(self) -> None: result_dict = self.get_built_result() # We need to set the artifacts to pass information # to the frontend @@ -206,7 +206,7 @@ def __init__(self, data: NodeData, graph): self.steps = [self._build, self._run] self.is_interface_component = True - def build_stream_url(self): + def build_stream_url(self) -> str: return f"/api/v1/build/{self.graph.flow_id}/{self.id}/stream" def _built_object_repr(self): @@ -352,21 +352,17 @@ def _process_data_component(self): self.artifacts = DataOutputResponse(data=artifacts) return self._built_object - async def _run(self, *args, **kwargs): - if self.is_interface_component: - if self.vertex_type in CHAT_COMPONENTS: - message = self._process_chat_component() - elif self.vertex_type in RECORDS_COMPONENTS: - message = self._process_data_component() - if isinstance(self._built_object, AsyncIterator | Iterator): - if self.params.get("return_data", False): - self._built_object = Data(text=message, data=self.artifacts) - else: - self._built_object = message - self._built_result = self._built_object - - else: - await super()._run(*args, **kwargs) + async def _run(self, *args, **kwargs) -> None: # noqa: ARG002 + if self.vertex_type in CHAT_COMPONENTS: + message = self._process_chat_component() + elif self.vertex_type in RECORDS_COMPONENTS: + message = self._process_data_component() + if isinstance(self._built_object, AsyncIterator | Iterator): + if self.params.get("return_data", False): + self._built_object = Data(text=message, data=self.artifacts) + else: + self._built_object = message + self._built_result = self._built_object async def stream(self): iterator = self.params.get(INPUT_FIELD_NAME, None) @@ -452,7 +448,7 @@ async def stream(self): self._validate_built_object() self._built = True - async def consume_async_generator(self): + async def consume_async_generator(self) -> None: async for _ in self.stream(): pass diff --git a/src/backend/base/langflow/initial_setup/setup.py b/src/backend/base/langflow/initial_setup/setup.py index 5c84f942d4d6..14deb499f25f 100644 --- a/src/backend/base/langflow/initial_setup/setup.py +++ b/src/backend/base/langflow/initial_setup/setup.py @@ -12,6 +12,7 @@ import orjson from emoji import demojize, purely_emoji from loguru import logger +from sqlalchemy.exc import NoResultFound from sqlmodel import select from langflow.base.constants import ( @@ -340,7 +341,7 @@ def update_edges_with_latest_component_versions(project_data): return project_data_copy -def log_node_changes(node_changes_log): +def log_node_changes(node_changes_log) -> None: # The idea here is to log the changes that were made to the nodes in debug # Something like: # Node: "Node Name" was updated with the following changes: @@ -377,8 +378,11 @@ def load_starter_projects(retries=3, delay=1) -> list[tuple[Path, dict]]: return starter_projects -def copy_profile_pictures(): +def copy_profile_pictures() -> None: config_dir = get_storage_service().settings_service.settings.config_dir + if config_dir is None: + msg = "Config dir is not set in the settings" + raise ValueError(msg) origin = Path(__file__).parent / "profile_pictures" target = Path(config_dir) / "profile_pictures" @@ -425,7 +429,7 @@ def get_project_data(project): ) -def update_project_file(project_path: Path, project: dict, updated_project_data): +def update_project_file(project_path: Path, project: dict, updated_project_data) -> None: project["data"] = updated_project_data project_path.write_text(orjson.dumps(project, option=ORJSON_OPTIONS).decode(), encoding="utf-8") logger.info(f"Updated starter project {project['name']} file") @@ -440,7 +444,7 @@ def update_existing_project( project_data, project_icon, project_icon_bg_color, -): +) -> None: logger.info(f"Updating starter project {project_name}") existing_project.data = project_data existing_project.folder = STARTER_FOLDER_NAME @@ -463,7 +467,7 @@ def create_new_project( project_icon, project_icon_bg_color, new_folder_id, -): +) -> None: logger.debug(f"Creating starter project {project_name}") new_project = FlowCreate( name=project_name, @@ -485,7 +489,7 @@ def get_all_flows_similar_to_project(session, folder_id): return session.exec(select(Folder).where(Folder.id == folder_id)).first().flows -def delete_start_projects(session, folder_id): +def delete_start_projects(session, folder_id) -> None: flows = session.exec(select(Folder).where(Folder.id == folder_id)).first().flows for flow in flows: session.delete(flow) @@ -516,7 +520,7 @@ def _is_valid_uuid(val): return str(uuid_obj) == val -def load_flows_from_directory(): +def load_flows_from_directory() -> None: """On langflow startup, this loads all flows from the directory specified in the settings. All flows are uploaded into the default folder for the superuser. @@ -531,7 +535,11 @@ def load_flows_from_directory(): return with session_scope() as session: - user_id = get_user_by_username(session, settings_service.auth_settings.SUPERUSER).id + user = get_user_by_username(session, settings_service.auth_settings.SUPERUSER) + if user is None: + msg = "Superuser not found in the database" + raise NoResultFound(msg) + user_id = user.id _flows_path = Path(flows_path) files = [f for f in _flows_path.iterdir() if f.is_file()] for f in files: @@ -592,7 +600,7 @@ def find_existing_flow(session, flow_id, flow_endpoint_name): return None -async def create_or_update_starter_projects(get_all_components_coro: Awaitable[dict]): +async def create_or_update_starter_projects(get_all_components_coro: Awaitable[dict]) -> None: try: all_types_dict = await get_all_components_coro except Exception: @@ -647,7 +655,7 @@ async def create_or_update_starter_projects(get_all_components_coro: Awaitable[d ) -def initialize_super_user_if_needed(): +def initialize_super_user_if_needed() -> None: settings_service = get_settings_service() if not settings_service.auth_settings.AUTO_LOGIN: return diff --git a/src/backend/base/langflow/interface/listing.py b/src/backend/base/langflow/interface/listing.py index 2a950d4c0888..d7f88b311037 100644 --- a/src/backend/base/langflow/interface/listing.py +++ b/src/backend/base/langflow/interface/listing.py @@ -3,7 +3,7 @@ class AllTypesDict(LazyLoadDictBase): - def __init__(self): + def __init__(self) -> None: self._all_types_dict = None def _build_dict(self): diff --git a/src/backend/base/langflow/interface/run.py b/src/backend/base/langflow/interface/run.py index 5effecd7f2d1..aa2051f7da3c 100644 --- a/src/backend/base/langflow/interface/run.py +++ b/src/backend/base/langflow/interface/run.py @@ -19,7 +19,7 @@ def get_memory_key(langchain_object): return None # or some other default value or action -def update_memory_keys(langchain_object, possible_new_mem_key): +def update_memory_keys(langchain_object, possible_new_mem_key) -> None: """Update the memory keys in the LangChain object's memory attribute. Given a LangChain object and a possible new memory key, this function updates the input and output keys in the diff --git a/src/backend/base/langflow/interface/utils.py b/src/backend/base/langflow/interface/utils.py index df7847c5eb23..900560ea9321 100644 --- a/src/backend/base/langflow/interface/utils.py +++ b/src/backend/base/langflow/interface/utils.py @@ -89,7 +89,7 @@ def extract_input_variables_from_prompt(prompt: str) -> list[str]: return variables -def setup_llm_caching(): +def setup_llm_caching() -> None: """Setup LLM caching.""" settings_service = get_settings_service() try: @@ -100,7 +100,7 @@ def setup_llm_caching(): logger.opt(exception=True).warning("Could not setup LLM caching.") -def set_langchain_cache(settings): +def set_langchain_cache(settings) -> None: from langchain.globals import set_llm_cache from langflow.interface.importing.utils import import_class diff --git a/src/backend/base/langflow/logging/logger.py b/src/backend/base/langflow/logging/logger.py index 8263eebf34c9..02540dd8cb37 100644 --- a/src/backend/base/langflow/logging/logger.py +++ b/src/backend/base/langflow/logging/logger.py @@ -42,7 +42,7 @@ def __init__( def get_write_lock(self) -> Lock: return self._wlock - def write(self, message: str): + def write(self, message: str) -> None: record = json.loads(message) log_entry = record["text"] epoch = int(record["record"]["time"]["timestamp"] * 1000) @@ -52,7 +52,7 @@ def write(self, message: str): self.buffer.popleft() self.buffer.append((epoch, log_entry)) - def __len__(self): + def __len__(self) -> int: return len(self.buffer) def get_after_timestamp(self, timestamp: int, lines: int = 5) -> dict[int, str]: @@ -123,7 +123,7 @@ def serialize_log(record): return orjson.dumps(subset) -def patching(record): +def patching(record) -> None: record["extra"]["serialized"] = serialize_log(record) if DEV is False: record.pop("exception", None) @@ -142,7 +142,7 @@ def configure( log_file: Path | None = None, disable: bool | None = False, log_env: str | None = None, -): +) -> None: if disable and log_level is None and log_file is None: logger.disable("langflow") if os.getenv("LANGFLOW_LOG_LEVEL", "").upper() in VALID_LOG_LEVELS and log_level is None: @@ -205,14 +205,14 @@ def configure( setup_gunicorn_logger() -def setup_uvicorn_logger(): +def setup_uvicorn_logger() -> None: loggers = (logging.getLogger(name) for name in logging.root.manager.loggerDict if name.startswith("uvicorn.")) for uvicorn_logger in loggers: uvicorn_logger.handlers = [] logging.getLogger("uvicorn").handlers = [InterceptHandler()] -def setup_gunicorn_logger(): +def setup_gunicorn_logger() -> None: logging.getLogger("gunicorn.error").handlers = [InterceptHandler()] logging.getLogger("gunicorn.access").handlers = [InterceptHandler()] @@ -223,7 +223,7 @@ class InterceptHandler(logging.Handler): See https://loguru.readthedocs.io/en/stable/overview.html#entirely-compatible-with-standard-logging. """ - def emit(self, record): + def emit(self, record) -> None: # Get corresponding Loguru level if it exists try: level = logger.level(record.levelname).name @@ -232,7 +232,7 @@ def emit(self, record): # Find caller from where originated the logged message frame, depth = logging.currentframe(), 2 - while frame.f_code.co_filename == logging.__file__: + while frame.f_code.co_filename == logging.__file__ and frame.f_back: frame = frame.f_back depth += 1 diff --git a/src/backend/base/langflow/logging/setup.py b/src/backend/base/langflow/logging/setup.py index d99cc8f36fae..2d207b28f8ac 100644 --- a/src/backend/base/langflow/logging/setup.py +++ b/src/backend/base/langflow/logging/setup.py @@ -3,14 +3,14 @@ LOGGING_CONFIGURED = False -def disable_logging(): +def disable_logging() -> None: global LOGGING_CONFIGURED # noqa: PLW0603 if not LOGGING_CONFIGURED: logger.disable("langflow") LOGGING_CONFIGURED = True -def enable_logging(): +def enable_logging() -> None: global LOGGING_CONFIGURED # noqa: PLW0603 logger.enable("langflow") LOGGING_CONFIGURED = True diff --git a/src/backend/base/langflow/main.py b/src/backend/base/langflow/main.py index fafbd6157d05..507ac180d2f9 100644 --- a/src/backend/base/langflow/main.py +++ b/src/backend/base/langflow/main.py @@ -40,7 +40,7 @@ class RequestCancelledMiddleware(BaseHTTPMiddleware): - def __init__(self, app): + def __init__(self, app) -> None: super().__init__(app) async def dispatch(self, request: Request, call_next): @@ -224,7 +224,7 @@ async def exception_handler(_request: Request, exc: Exception): return app -def setup_sentry(app: FastAPI): +def setup_sentry(app: FastAPI) -> None: settings = get_settings_service().settings if settings.sentry_dsn: import sentry_sdk @@ -238,7 +238,7 @@ def setup_sentry(app: FastAPI): app.add_middleware(SentryAsgiMiddleware) -def setup_static_files(app: FastAPI, static_files_dir: Path): +def setup_static_files(app: FastAPI, static_files_dir: Path) -> None: """Setup the static files directory. Args: diff --git a/src/backend/base/langflow/memory.py b/src/backend/base/langflow/memory.py index 72d81e16ccf2..86958c98aa00 100644 --- a/src/backend/base/langflow/memory.py +++ b/src/backend/base/langflow/memory.py @@ -86,7 +86,7 @@ def add_messagetables(messages: list[MessageTable], session: Session): return [MessageRead.model_validate(message, from_attributes=True) for message in messages] -def delete_messages(session_id: str): +def delete_messages(session_id: str) -> None: """Delete messages from the monitor service based on the provided session ID. Args: diff --git a/src/backend/base/langflow/processing/base.py b/src/backend/base/langflow/processing/base.py index 96bad52a9f11..3ef0909dbbdc 100644 --- a/src/backend/base/langflow/processing/base.py +++ b/src/backend/base/langflow/processing/base.py @@ -36,7 +36,7 @@ def get_langfuse_callback(trace_id): return None -def flush_langfuse_callback_if_present(callbacks: list[BaseCallbackHandler | CallbackHandler]): +def flush_langfuse_callback_if_present(callbacks: list[BaseCallbackHandler | CallbackHandler]) -> None: """If langfuse callback is present, run callback.langfuse.flush().""" for callback in callbacks: if hasattr(callback, "langfuse") and hasattr(callback.langfuse, "flush"): diff --git a/src/backend/base/langflow/schema/data.py b/src/backend/base/langflow/schema/data.py index 80e669aad37c..fbf73f50440e 100644 --- a/src/backend/base/langflow/schema/data.py +++ b/src/backend/base/langflow/schema/data.py @@ -165,7 +165,7 @@ def __getattr__(self, key): msg = f"'{type(self).__name__}' object has no attribute '{key}'" raise AttributeError(msg) from e - def __setattr__(self, key, value): + def __setattr__(self, key, value) -> None: """Set attribute-like values in the data dictionary. Allows attribute-like setting of values in the data dictionary. @@ -179,7 +179,7 @@ def __setattr__(self, key, value): else: self.data[key] = value - def __delattr__(self, key): + def __delattr__(self, key) -> None: """Allows attribute-like deletion from the data dictionary.""" if key in {"data", "text_key"} or key.startswith("_"): super().__delattr__(key) @@ -204,7 +204,7 @@ def __str__(self) -> str: logger.opt(exception=True).debug("Error converting Data to JSON") return str(self.data) - def __contains__(self, key): + def __contains__(self, key) -> bool: return key in self.data def __eq__(self, other): diff --git a/src/backend/base/langflow/schema/dotdict.py b/src/backend/base/langflow/schema/dotdict.py index f7e30ee6157d..93d57aec93f5 100644 --- a/src/backend/base/langflow/schema/dotdict.py +++ b/src/backend/base/langflow/schema/dotdict.py @@ -33,7 +33,7 @@ def __getattr__(self, attr): else: return value - def __setattr__(self, key, value): + def __setattr__(self, key, value) -> None: """Override attribute setting to work as dictionary item assignment. Args: @@ -44,7 +44,7 @@ def __setattr__(self, key, value): value = dotdict(value) self[key] = value - def __delattr__(self, key): + def __delattr__(self, key) -> None: """Override attribute deletion to work as dictionary item deletion. Args: diff --git a/src/backend/base/langflow/schema/graph.py b/src/backend/base/langflow/schema/graph.py index fafaea1fb011..0cbabf914a49 100644 --- a/src/backend/base/langflow/schema/graph.py +++ b/src/backend/base/langflow/schema/graph.py @@ -37,10 +37,10 @@ class Tweaks(RootModel): def __getitem__(self, key): return self.root[key] - def __setitem__(self, key, value): + def __setitem__(self, key, value) -> None: self.root[key] = value - def __delitem__(self, key): + def __delitem__(self, key) -> None: del self.root[key] def items(self): diff --git a/src/backend/base/langflow/schema/image.py b/src/backend/base/langflow/schema/image.py index 008de8140fb2..0550c837415b 100644 --- a/src/backend/base/langflow/schema/image.py +++ b/src/backend/base/langflow/schema/image.py @@ -8,7 +8,7 @@ IMAGE_ENDPOINT = "/files/images/" -def is_image_file(file_path): +def is_image_file(file_path) -> bool: try: with PILImage.open(file_path) as img: img.verify() # Verify that it is, in fact, an image @@ -61,5 +61,5 @@ def to_content_dict(self): "image_url": self.to_base64(), } - def get_url(self): + def get_url(self) -> str: return f"{IMAGE_ENDPOINT}{self.path}" diff --git a/src/backend/base/langflow/schema/message.py b/src/backend/base/langflow/schema/message.py index fd2a2ab7a60d..e17ecfaf797f 100644 --- a/src/backend/base/langflow/schema/message.py +++ b/src/backend/base/langflow/schema/message.py @@ -94,7 +94,7 @@ def model_post_init(self, __context: Any) -> None: if "timestamp" not in self.data: self.data["timestamp"] = self.timestamp - def set_flow_id(self, flow_id: str): + def set_flow_id(self, flow_id: str) -> None: self.flow_id = flow_id def to_lc_message( diff --git a/src/backend/base/langflow/server.py b/src/backend/base/langflow/server.py index 908d08021d6c..fd8167fb0ad1 100644 --- a/src/backend/base/langflow/server.py +++ b/src/backend/base/langflow/server.py @@ -35,14 +35,14 @@ class Logger(glogging.Logger): gunicorn logs to loguru. """ - def __init__(self, cfg): + def __init__(self, cfg) -> None: super().__init__(cfg) logging.getLogger("gunicorn.error").handlers = [InterceptHandler()] logging.getLogger("gunicorn.access").handlers = [InterceptHandler()] class LangflowApplication(BaseApplication): - def __init__(self, app, options=None): + def __init__(self, app, options=None) -> None: self.options = options or {} self.options["worker_class"] = "langflow.server.LangflowUvicornWorker" @@ -50,7 +50,7 @@ def __init__(self, app, options=None): self.application = app super().__init__() - def load_config(self): + def load_config(self) -> None: config = {key: value for key, value in self.options.items() if key in self.cfg.settings and value is not None} for key, value in config.items(): self.cfg.set(key.lower(), value) diff --git a/src/backend/base/langflow/services/auth/factory.py b/src/backend/base/langflow/services/auth/factory.py index 63d5d2a6d17c..fbf734b3a8bf 100644 --- a/src/backend/base/langflow/services/auth/factory.py +++ b/src/backend/base/langflow/services/auth/factory.py @@ -5,7 +5,7 @@ class AuthServiceFactory(ServiceFactory): name = "auth_service" - def __init__(self): + def __init__(self) -> None: super().__init__(AuthService) def create(self, settings_service): diff --git a/src/backend/base/langflow/services/base.py b/src/backend/base/langflow/services/base.py index f40a62fc53a5..a903332e1591 100644 --- a/src/backend/base/langflow/services/base.py +++ b/src/backend/base/langflow/services/base.py @@ -21,8 +21,8 @@ def get_schema(self): } return schema - async def teardown(self): + async def teardown(self) -> None: return - def set_ready(self): + def set_ready(self) -> None: self.ready = True diff --git a/src/backend/base/langflow/services/cache/base.py b/src/backend/base/langflow/services/cache/base.py index d6b23cc40986..7e1fcee0359a 100644 --- a/src/backend/base/langflow/services/cache/base.py +++ b/src/backend/base/langflow/services/cache/base.py @@ -60,7 +60,7 @@ def clear(self, lock: LockType | None = None): """Clear all items from the cache.""" @abc.abstractmethod - def __contains__(self, key): + def __contains__(self, key) -> bool: """Check if the key is in the cache. Args: @@ -79,7 +79,7 @@ def __getitem__(self, key): """ @abc.abstractmethod - def __setitem__(self, key, value): + def __setitem__(self, key, value) -> None: """Add an item to the cache using the square bracket notation. Args: @@ -88,7 +88,7 @@ def __setitem__(self, key, value): """ @abc.abstractmethod - def __delitem__(self, key): + def __delitem__(self, key) -> None: """Remove an item from the cache using the square bracket notation. Args: @@ -147,7 +147,7 @@ async def clear(self, lock: AsyncLockType | None = None): """Clear all items from the cache.""" @abc.abstractmethod - def __contains__(self, key): + def __contains__(self, key) -> bool: """Check if the key is in the cache. Args: diff --git a/src/backend/base/langflow/services/cache/disk.py b/src/backend/base/langflow/services/cache/disk.py index 9c866d4437c4..ed3412e327fb 100644 --- a/src/backend/base/langflow/services/cache/disk.py +++ b/src/backend/base/langflow/services/cache/disk.py @@ -11,7 +11,7 @@ class AsyncDiskCache(AsyncBaseCacheService, Generic[AsyncLockType]): - def __init__(self, cache_dir, max_size=None, expiration_time=3600): + def __init__(self, cache_dir, max_size=None, expiration_time=3600) -> None: self.cache = Cache(cache_dir) # Let's clear the cache for now to maintain a similar # behavior as the in-memory cache @@ -40,56 +40,56 @@ async def _get(self, key): await self._delete(key) # Log before deleting the expired item return CACHE_MISS - async def set(self, key, value, lock: asyncio.Lock | None = None): + async def set(self, key, value, lock: asyncio.Lock | None = None) -> None: if not lock: async with self.lock: await self._set(key, value) else: await self._set(key, value) - async def _set(self, key, value): + async def _set(self, key, value) -> None: if self.max_size and len(self.cache) >= self.max_size: await asyncio.to_thread(self.cache.cull) item = {"value": pickle.dumps(value) if not isinstance(value, str | bytes) else value, "time": time.time()} await asyncio.to_thread(self.cache.set, key, item) - async def delete(self, key, lock: asyncio.Lock | None = None): + async def delete(self, key, lock: asyncio.Lock | None = None) -> None: if not lock: async with self.lock: await self._delete(key) else: await self._delete(key) - async def _delete(self, key): + async def _delete(self, key) -> None: await asyncio.to_thread(self.cache.delete, key) - async def clear(self, lock: asyncio.Lock | None = None): + async def clear(self, lock: asyncio.Lock | None = None) -> None: if not lock: async with self.lock: await self._clear() else: await self._clear() - async def _clear(self): + async def _clear(self) -> None: await asyncio.to_thread(self.cache.clear) - async def upsert(self, key, value, lock: asyncio.Lock | None = None): + async def upsert(self, key, value, lock: asyncio.Lock | None = None) -> None: if not lock: async with self.lock: await self._upsert(key, value) else: await self._upsert(key, value) - async def _upsert(self, key, value): + async def _upsert(self, key, value) -> None: existing_value = await self.get(key) if existing_value is not CACHE_MISS and isinstance(existing_value, dict) and isinstance(value, dict): existing_value.update(value) value = existing_value await self.set(key, value) - def __contains__(self, key): + def __contains__(self, key) -> bool: return asyncio.run(asyncio.to_thread(self.cache.__contains__, key)) - async def teardown(self): + async def teardown(self) -> None: # Clean up the cache directory self.cache.clear(retry=True) diff --git a/src/backend/base/langflow/services/cache/factory.py b/src/backend/base/langflow/services/cache/factory.py index f9b1eac7e4a7..3def8ebc1757 100644 --- a/src/backend/base/langflow/services/cache/factory.py +++ b/src/backend/base/langflow/services/cache/factory.py @@ -12,7 +12,7 @@ class CacheServiceFactory(ServiceFactory): - def __init__(self): + def __init__(self) -> None: super().__init__(CacheService) def create(self, settings_service: SettingsService): diff --git a/src/backend/base/langflow/services/cache/service.py b/src/backend/base/langflow/services/cache/service.py index 5ddd5752854d..baec3dbd4025 100644 --- a/src/backend/base/langflow/services/cache/service.py +++ b/src/backend/base/langflow/services/cache/service.py @@ -36,14 +36,14 @@ class ThreadingInMemoryCache(CacheService, Generic[LockType]): b = cache["b"] """ - def __init__(self, max_size=None, expiration_time=60 * 60): + def __init__(self, max_size=None, expiration_time=60 * 60) -> None: """Initialize a new InMemoryCache instance. Args: max_size (int, optional): Maximum number of items to store in the cache. expiration_time (int, optional): Time in seconds after which a cached item expires. Default is 1 hour. """ - self._cache = OrderedDict() + self._cache: OrderedDict = OrderedDict() self._lock = threading.RLock() self.max_size = max_size self.expiration_time = expiration_time @@ -72,7 +72,7 @@ def _get_without_lock(self, key): self.delete(key) return None - def set(self, key, value, lock: Union[threading.Lock, None] = None): # noqa: UP007 + def set(self, key, value, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP007 """Add an item to the cache. If the cache is full, the least recently used item is evicted. @@ -93,7 +93,7 @@ def set(self, key, value, lock: Union[threading.Lock, None] = None): # noqa: UP self._cache[key] = {"value": value, "time": time.time()} - def upsert(self, key, value, lock: Union[threading.Lock, None] = None): # noqa: UP007 + def upsert(self, key, value, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP007 """Inserts or updates a value in the cache. If the existing value and the new value are both dictionaries, they are merged. @@ -130,16 +130,16 @@ def get_or_set(self, key, value, lock: Union[threading.Lock, None] = None): # n self.set(key, value) return value - def delete(self, key, lock: Union[threading.Lock, None] = None): # noqa: UP007 + def delete(self, key, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP007 with lock or self._lock: self._cache.pop(key, None) - def clear(self, lock: Union[threading.Lock, None] = None): # noqa: UP007 + def clear(self, lock: Union[threading.Lock, None] = None) -> None: # noqa: UP007 """Clear all items from the cache.""" with lock or self._lock: self._cache.clear() - def __contains__(self, key): + def __contains__(self, key) -> bool: """Check if the key is in the cache.""" return key in self._cache @@ -147,19 +147,19 @@ def __getitem__(self, key): """Retrieve an item from the cache using the square bracket notation.""" return self.get(key) - def __setitem__(self, key, value): + def __setitem__(self, key, value) -> None: """Add an item to the cache using the square bracket notation.""" self.set(key, value) - def __delitem__(self, key): + def __delitem__(self, key) -> None: """Remove an item from the cache using the square bracket notation.""" self.delete(key) - def __len__(self): + def __len__(self) -> int: """Return the number of items in the cache.""" return len(self._cache) - def __repr__(self): + def __repr__(self) -> str: """Return a string representation of the InMemoryCache instance.""" return f"InMemoryCache(max_size={self.max_size}, expiration_time={self.expiration_time})" @@ -185,7 +185,7 @@ class RedisCache(AsyncBaseCacheService, Generic[LockType]): b = cache["b"] """ - def __init__(self, host="localhost", port=6379, db=0, url=None, expiration_time=60 * 60): + def __init__(self, host="localhost", port=6379, db=0, url=None, expiration_time=60 * 60) -> None: """Initialize a new RedisCache instance. Args: @@ -215,7 +215,7 @@ def __init__(self, host="localhost", port=6379, db=0, url=None, expiration_time= self.expiration_time = expiration_time # check connection - def is_connected(self): + def is_connected(self) -> bool: """Check if the Redis client is connected.""" import redis @@ -234,7 +234,7 @@ async def get(self, key, lock=None): return pickle.loads(value) if value else None @override - async def set(self, key, value, lock=None): + async def set(self, key, value, lock=None) -> None: try: if pickled := pickle.dumps(value): result = await self._client.setex(str(key), self.expiration_time, pickled) @@ -246,7 +246,7 @@ async def set(self, key, value, lock=None): raise TypeError(msg) from exc @override - async def upsert(self, key, value, lock=None): + async def upsert(self, key, value, lock=None) -> None: """Inserts or updates a value in the cache. If the existing value and the new value are both dictionaries, they are merged. @@ -266,28 +266,28 @@ async def upsert(self, key, value, lock=None): await self.set(key, value) @override - async def delete(self, key, lock=None): + async def delete(self, key, lock=None) -> None: await self._client.delete(key) @override - async def clear(self, lock=None): + async def clear(self, lock=None) -> None: """Clear all items from the cache.""" await self._client.flushdb() - def __contains__(self, key): + def __contains__(self, key) -> bool: """Check if the key is in the cache.""" if key is None: return False - return asyncio.run(self._client.exists(str(key))) + return bool(asyncio.run(self._client.exists(str(key)))) - def __repr__(self): + def __repr__(self) -> str: """Return a string representation of the RedisCache instance.""" return f"RedisCache(expiration_time={self.expiration_time})" class AsyncInMemoryCache(AsyncBaseCacheService, Generic[AsyncLockType]): - def __init__(self, max_size=None, expiration_time=3600): - self.cache = OrderedDict() + def __init__(self, max_size=None, expiration_time=3600) -> None: + self.cache: OrderedDict = OrderedDict() self.lock = asyncio.Lock() self.max_size = max_size @@ -310,7 +310,7 @@ async def _get(self, key): await self._delete(key) # Log before deleting the expired item return CACHE_MISS - async def set(self, key, value, lock: asyncio.Lock | None = None): + async def set(self, key, value, lock: asyncio.Lock | None = None) -> None: if not lock: async with self.lock: await self._set( @@ -323,46 +323,46 @@ async def set(self, key, value, lock: asyncio.Lock | None = None): value, ) - async def _set(self, key, value): + async def _set(self, key, value) -> None: if self.max_size and len(self.cache) >= self.max_size: self.cache.popitem(last=False) self.cache[key] = {"value": value, "time": time.time()} self.cache.move_to_end(key) - async def delete(self, key, lock: asyncio.Lock | None = None): + async def delete(self, key, lock: asyncio.Lock | None = None) -> None: if not lock: async with self.lock: await self._delete(key) else: await self._delete(key) - async def _delete(self, key): + async def _delete(self, key) -> None: if key in self.cache: del self.cache[key] - async def clear(self, lock: asyncio.Lock | None = None): + async def clear(self, lock: asyncio.Lock | None = None) -> None: if not lock: async with self.lock: await self._clear() else: await self._clear() - async def _clear(self): + async def _clear(self) -> None: self.cache.clear() - async def upsert(self, key, value, lock: asyncio.Lock | None = None): + async def upsert(self, key, value, lock: asyncio.Lock | None = None) -> None: if not lock: async with self.lock: await self._upsert(key, value) else: await self._upsert(key, value) - async def _upsert(self, key, value): + async def _upsert(self, key, value) -> None: existing_value = await self.get(key) if existing_value is not None and isinstance(existing_value, dict) and isinstance(value, dict): existing_value.update(value) value = existing_value await self.set(key, value) - def __contains__(self, key): + def __contains__(self, key) -> bool: return key in self.cache diff --git a/src/backend/base/langflow/services/cache/utils.py b/src/backend/base/langflow/services/cache/utils.py index 919c9868f6b5..5ca3e0ada0e2 100644 --- a/src/backend/base/langflow/services/cache/utils.py +++ b/src/backend/base/langflow/services/cache/utils.py @@ -19,10 +19,10 @@ class CacheMiss: - def __repr__(self): + def __repr__(self) -> str: return "" - def __bool__(self): + def __bool__(self) -> bool: return False @@ -40,7 +40,7 @@ def wrapper(*args, **kwargs): @create_cache_folder -def clear_old_cache_files(max_cache_size: int = 3): +def clear_old_cache_files(max_cache_size: int = 3) -> None: cache_dir = Path(tempfile.gettempdir()) / PREFIX cache_files = list(cache_dir.glob("*.dill")) @@ -155,7 +155,7 @@ def save_uploaded_file(file: UploadFile, folder_name): return file_path -def update_build_status(cache_service, flow_id: str, status: "BuildStatus"): +def update_build_status(cache_service, flow_id: str, status: "BuildStatus") -> None: cached_flow = cache_service[flow_id] if cached_flow is None: msg = f"Flow {flow_id} not found in cache" diff --git a/src/backend/base/langflow/services/chat/cache.py b/src/backend/base/langflow/services/chat/cache.py index 3c6e2a076bb3..8943015e7bf8 100644 --- a/src/backend/base/langflow/services/chat/cache.py +++ b/src/backend/base/langflow/services/chat/cache.py @@ -11,18 +11,18 @@ class Subject: """Base class for implementing the observer pattern.""" - def __init__(self): + def __init__(self) -> None: self.observers: list[Callable[[], None]] = [] - def attach(self, observer: Callable[[], None]): + def attach(self, observer: Callable[[], None]) -> None: """Attach an observer to the subject.""" self.observers.append(observer) - def detach(self, observer: Callable[[], None]): + def detach(self, observer: Callable[[], None]) -> None: """Detach an observer from the subject.""" self.observers.remove(observer) - def notify(self): + def notify(self) -> None: """Notify all observers about an event.""" for observer in self.observers: if observer is None: @@ -33,18 +33,18 @@ def notify(self): class AsyncSubject: """Base class for implementing the async observer pattern.""" - def __init__(self): + def __init__(self) -> None: self.observers: list[Callable[[], Awaitable]] = [] - def attach(self, observer: Callable[[], Awaitable]): + def attach(self, observer: Callable[[], Awaitable]) -> None: """Attach an observer to the subject.""" self.observers.append(observer) - def detach(self, observer: Callable[[], Awaitable]): + def detach(self, observer: Callable[[], Awaitable]) -> None: """Detach an observer from the subject.""" self.observers.remove(observer) - async def notify(self): + async def notify(self) -> None: """Notify all observers about an event.""" for observer in self.observers: if observer is None: @@ -57,11 +57,11 @@ class CacheService(Subject, Service): name = "cache_service" - def __init__(self): + def __init__(self) -> None: super().__init__() - self._cache = {} - self.current_client_id = None - self.current_cache = {} + self._cache: dict[str, Any] = {} + self.current_client_id: str | None = None + self.current_cache: dict[str, Any] = {} @contextmanager def set_client_id(self, client_id: str): @@ -77,9 +77,9 @@ def set_client_id(self, client_id: str): yield finally: self.current_client_id = previous_client_id - self.current_cache = self._cache.get(self.current_client_id, {}) + self.current_cache = self._cache.setdefault(previous_client_id, {}) if previous_client_id else {} - def add(self, name: str, obj: Any, obj_type: str, extension: str | None = None): + def add(self, name: str, obj: Any, obj_type: str, extension: str | None = None) -> None: """Add an object to the current client's cache. Args: @@ -100,7 +100,7 @@ def add(self, name: str, obj: Any, obj_type: str, extension: str | None = None): } self.notify() - def add_pandas(self, name: str, obj: Any): + def add_pandas(self, name: str, obj: Any) -> None: """Add a pandas DataFrame or Series to the current client's cache. Args: @@ -113,7 +113,7 @@ def add_pandas(self, name: str, obj: Any): msg = "Object is not a pandas DataFrame or Series" raise TypeError(msg) - def add_image(self, name: str, obj: Any, extension: str = "png"): + def add_image(self, name: str, obj: Any, extension: str = "png") -> None: """Add a PIL Image to the current client's cache. Args: diff --git a/src/backend/base/langflow/services/chat/factory.py b/src/backend/base/langflow/services/chat/factory.py index 337488e0f444..e554a34dafad 100644 --- a/src/backend/base/langflow/services/chat/factory.py +++ b/src/backend/base/langflow/services/chat/factory.py @@ -3,7 +3,7 @@ class ChatServiceFactory(ServiceFactory): - def __init__(self): + def __init__(self) -> None: super().__init__(ChatService) def create(self): diff --git a/src/backend/base/langflow/services/chat/service.py b/src/backend/base/langflow/services/chat/service.py index 333e6357072a..fc8c3a8be808 100644 --- a/src/backend/base/langflow/services/chat/service.py +++ b/src/backend/base/langflow/services/chat/service.py @@ -13,9 +13,9 @@ class ChatService(Service): name = "chat_service" - def __init__(self): - self._async_cache_locks = defaultdict(asyncio.Lock) - self._sync_cache_locks = defaultdict(RLock) + def __init__(self) -> None: + self._async_cache_locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + self._sync_cache_locks: dict[str, RLock] = defaultdict(RLock) self.cache_service = get_cache_service() def _get_lock(self, key: str): @@ -101,7 +101,7 @@ async def get_cache(self, key: str, lock: asyncio.Lock | None = None) -> Any: """ return await self._perform_cache_operation("get", key, lock=lock or self._get_lock(key)) - async def clear_cache(self, key: str, lock: asyncio.Lock | None = None): + async def clear_cache(self, key: str, lock: asyncio.Lock | None = None) -> None: """Clear the cache for a client. Args: diff --git a/src/backend/base/langflow/services/database/factory.py b/src/backend/base/langflow/services/database/factory.py index e57959053f41..8e469369ee0d 100644 --- a/src/backend/base/langflow/services/database/factory.py +++ b/src/backend/base/langflow/services/database/factory.py @@ -10,7 +10,7 @@ class DatabaseServiceFactory(ServiceFactory): - def __init__(self): + def __init__(self) -> None: super().__init__(DatabaseService) def create(self, settings_service: SettingsService): diff --git a/src/backend/base/langflow/services/database/models/api_key/model.py b/src/backend/base/langflow/services/database/models/api_key/model.py index f2864e6fad1c..82c7d4ec6591 100644 --- a/src/backend/base/langflow/services/database/models/api_key/model.py +++ b/src/backend/base/langflow/services/database/models/api_key/model.py @@ -59,6 +59,6 @@ class ApiKeyRead(ApiKeyBase): @field_validator("api_key") @classmethod - def mask_api_key(cls, v): + def mask_api_key(cls, v) -> str: # This validator will always run, and will mask the API key return f"{v[:8]}{'*' * (len(v) - 8)}" diff --git a/src/backend/base/langflow/services/database/service.py b/src/backend/base/langflow/services/database/service.py index ec63cd993841..63892797c185 100644 --- a/src/backend/base/langflow/services/database/service.py +++ b/src/backend/base/langflow/services/database/service.py @@ -45,7 +45,7 @@ def __init__(self, settings_service: SettingsService): self.alembic_cfg_path = langflow_dir / "alembic.ini" self.engine = self._create_engine() - def reload_engine(self): + def reload_engine(self) -> None: self.engine = self._create_engine() def _create_engine(self) -> Engine: @@ -82,13 +82,13 @@ def _create_engine(self) -> Engine: msg = "Error creating database engine" raise RuntimeError(msg) from exc - def on_connection(self, dbapi_connection, _connection_record): + def on_connection(self, dbapi_connection, _connection_record) -> None: from sqlite3 import Connection as sqliteConnection if isinstance(dbapi_connection, sqliteConnection): - pragmas: dict | None = self.settings_service.settings.sqlite_pragmas + pragmas: dict = self.settings_service.settings.sqlite_pragmas or {} pragmas_list = [] - for key, val in pragmas.items() or {}: + for key, val in pragmas.items(): pragmas_list.append(f"PRAGMA {key} = {val}") logger.info(f"sqlite connection, setting pragmas: {pragmas_list}") if pragmas_list: @@ -107,7 +107,7 @@ def with_session(self): with Session(self.engine) as session: yield session - def migrate_flows_if_auto_login(self): + def migrate_flows_if_auto_login(self) -> None: # if auto_login is enabled, we need to migrate the flows # to the default superuser if they don't have a user id # associated with them @@ -161,14 +161,14 @@ def check_schema_health(self) -> bool: return True - def init_alembic(self, alembic_cfg): + def init_alembic(self, alembic_cfg) -> None: logger.info("Initializing alembic") command.ensure_version(alembic_cfg) # alembic_cfg.attributes["connection"].commit() command.upgrade(alembic_cfg, "head") logger.info("Alembic initialized") - def run_migrations(self, *, fix=False): + def run_migrations(self, *, fix=False) -> None: # First we need to check if alembic has been initialized # If not, we need to initialize it # if not self.script_location.exists(): # this is not the correct way to check if alembic has been initialized @@ -227,7 +227,7 @@ def run_migrations(self, *, fix=False): if fix: self.try_downgrade_upgrade_until_success(alembic_cfg) - def try_downgrade_upgrade_until_success(self, alembic_cfg, retries=5): + def try_downgrade_upgrade_until_success(self, alembic_cfg, retries=5) -> None: # Try -1 then head, if it fails, try -2 then head, etc. # until we reach the number of retries for i in range(1, retries + 1): @@ -273,7 +273,7 @@ def check_table(self, model): results.append(Result(name=column, type="column", success=True)) return results - def create_db_and_tables(self): + def create_db_and_tables(self) -> None: from sqlalchemy import inspect inspector = inspect(self.engine) @@ -308,7 +308,7 @@ def create_db_and_tables(self): logger.debug("Database and tables created successfully") - async def teardown(self): + async def teardown(self) -> None: logger.debug("Tearing down database") try: settings_service = get_settings_service() diff --git a/src/backend/base/langflow/services/database/utils.py b/src/backend/base/langflow/services/database/utils.py index a5065e892fc7..7337c5778d86 100644 --- a/src/backend/base/langflow/services/database/utils.py +++ b/src/backend/base/langflow/services/database/utils.py @@ -12,7 +12,7 @@ from langflow.services.database.service import DatabaseService -def initialize_database(*, fix_migration: bool = False): +def initialize_database(*, fix_migration: bool = False) -> None: logger.debug("Initializing database") from langflow.services.deps import get_db_service diff --git a/src/backend/base/langflow/services/factory.py b/src/backend/base/langflow/services/factory.py index 84800ac11c24..40146fd5e2bb 100644 --- a/src/backend/base/langflow/services/factory.py +++ b/src/backend/base/langflow/services/factory.py @@ -15,7 +15,7 @@ class ServiceFactory: def __init__( self, service_class, - ): + ) -> None: self.service_class = service_class self.dependencies = infer_service_types(self, import_all_services_into_a_dict()) @@ -23,7 +23,7 @@ def create(self, *args, **kwargs) -> "Service": raise self.service_class(*args, **kwargs) -def hash_factory(factory: type[ServiceFactory]) -> str: +def hash_factory(factory: ServiceFactory) -> str: return factory.service_class.__name__ @@ -31,15 +31,15 @@ def hash_dict(d: dict) -> str: return str(d) -def hash_infer_service_types_args(factory_class: type[ServiceFactory], available_services=None) -> str: - factory_hash = hash_factory(factory_class) +def hash_infer_service_types_args(factory: ServiceFactory, available_services=None) -> str: + factory_hash = hash_factory(factory) services_hash = hash_dict(available_services) return f"{factory_hash}_{services_hash}" @cached(cache=LRUCache(maxsize=10), key=hash_infer_service_types_args) -def infer_service_types(factory_class: type[ServiceFactory], available_services=None) -> list["ServiceType"]: - create_method = factory_class.create +def infer_service_types(factory: ServiceFactory, available_services=None) -> list["ServiceType"]: + create_method = factory.create type_hints = get_type_hints(create_method, globalns=available_services) service_types = [] for param_name, param_type in type_hints.items(): diff --git a/src/backend/base/langflow/services/manager.py b/src/backend/base/langflow/services/manager.py index 8acb6d1a0492..458c7a66ea3b 100644 --- a/src/backend/base/langflow/services/manager.py +++ b/src/backend/base/langflow/services/manager.py @@ -22,13 +22,13 @@ class NoFactoryRegisteredError(Exception): class ServiceManager: """Manages the creation of different services.""" - def __init__(self): + def __init__(self) -> None: self.services: dict[str, Service] = {} - self.factories = {} + self.factories: dict[str, ServiceFactory] = {} self.register_factories() self.keyed_lock = KeyedMemoryLockManager() - def register_factories(self): + def register_factories(self) -> None: for factory in self.get_factories(): try: self.register_factory(factory) @@ -38,7 +38,7 @@ def register_factories(self): def register_factory( self, service_factory: ServiceFactory, - ): + ) -> None: """Registers a new factory with dependencies.""" service_name = service_factory.service_class.name self.factories[service_name] = service_factory @@ -51,7 +51,7 @@ def get(self, service_name: ServiceType, default: ServiceFactory | None = None) return self.services[service_name] - def _create_service(self, service_name: ServiceType, default: ServiceFactory | None = None): + def _create_service(self, service_name: ServiceType, default: ServiceFactory | None = None) -> None: """Create a new service given its name, handling dependencies.""" logger.debug(f"Create service {service_name}") self._validate_service_creation(service_name, default) @@ -61,6 +61,9 @@ def _create_service(self, service_name: ServiceType, default: ServiceFactory | N if factory is None and default is not None: self.register_factory(default) factory = default + if factory is None: + msg = f"No factory registered for {service_name}" + raise NoFactoryRegisteredError(msg) for dependency in factory.dependencies: if dependency not in self.services: self._create_service(dependency) @@ -72,20 +75,20 @@ def _create_service(self, service_name: ServiceType, default: ServiceFactory | N self.services[service_name] = self.factories[service_name].create(**dependent_services) self.services[service_name].set_ready() - def _validate_service_creation(self, service_name: ServiceType, default: ServiceFactory | None = None): + def _validate_service_creation(self, service_name: ServiceType, default: ServiceFactory | None = None) -> None: """Validate whether the service can be created.""" if service_name not in self.factories and default is None: msg = f"No factory registered for the service class '{service_name.name}'" raise NoFactoryRegisteredError(msg) - def update(self, service_name: ServiceType): + def update(self, service_name: ServiceType) -> None: """Update a service by its name.""" if service_name in self.services: logger.debug(f"Update service {service_name}") self.services.pop(service_name, None) self.get(service_name) - async def teardown(self): + async def teardown(self) -> None: """Teardown all the services.""" for service in self.services.values(): if service is None: @@ -131,14 +134,14 @@ def get_factories(): service_manager = ServiceManager() -def initialize_settings_service(): +def initialize_settings_service() -> None: """Initialize the settings manager.""" from langflow.services.settings import factory as settings_factory service_manager.register_factory(settings_factory.SettingsServiceFactory()) -def initialize_session_service(): +def initialize_session_service() -> None: """Initialize the session manager.""" from langflow.services.cache import factory as cache_factory from langflow.services.session import factory as session_service_factory diff --git a/src/backend/base/langflow/services/plugins/base.py b/src/backend/base/langflow/services/plugins/base.py index a5ab14e5460b..5f20db136c10 100644 --- a/src/backend/base/langflow/services/plugins/base.py +++ b/src/backend/base/langflow/services/plugins/base.py @@ -2,10 +2,10 @@ class BasePlugin: - def initialize(self): + def initialize(self) -> None: pass - def teardown(self): + def teardown(self) -> None: pass def get(self) -> Any: diff --git a/src/backend/base/langflow/services/plugins/factory.py b/src/backend/base/langflow/services/plugins/factory.py index 72e541559755..e6048d5508eb 100644 --- a/src/backend/base/langflow/services/plugins/factory.py +++ b/src/backend/base/langflow/services/plugins/factory.py @@ -5,7 +5,7 @@ class PluginServiceFactory(ServiceFactory): - def __init__(self): + def __init__(self) -> None: super().__init__(PluginService) def create(self): diff --git a/src/backend/base/langflow/services/plugins/langfuse_plugin.py b/src/backend/base/langflow/services/plugins/langfuse_plugin.py index 2f69f6ea2fea..fe7389ac167b 100644 --- a/src/backend/base/langflow/services/plugins/langfuse_plugin.py +++ b/src/backend/base/langflow/services/plugins/langfuse_plugin.py @@ -22,7 +22,7 @@ def get(cls): return cls._instance @classmethod - def create(cls): + def create(cls) -> None: try: logger.debug("Creating Langfuse instance") from langfuse import Langfuse @@ -44,13 +44,13 @@ def create(cls): cls._instance = None @classmethod - def update(cls): + def update(cls) -> None: logger.debug("Updating Langfuse instance") cls._instance = None cls.create() @classmethod - def teardown(cls): + def teardown(cls) -> None: logger.debug("Tearing down Langfuse instance") if cls._instance is not None: cls._instance.flush() @@ -58,10 +58,10 @@ def teardown(cls): class LangfusePlugin(CallbackPlugin): - def initialize(self): + def initialize(self) -> None: LangfuseInstance.create() - def teardown(self): + def teardown(self) -> None: LangfuseInstance.teardown() def get(self): diff --git a/src/backend/base/langflow/services/plugins/service.py b/src/backend/base/langflow/services/plugins/service.py index ad9b11f35df2..c623cfdd6d78 100644 --- a/src/backend/base/langflow/services/plugins/service.py +++ b/src/backend/base/langflow/services/plugins/service.py @@ -13,13 +13,13 @@ class PluginService(Service): name = "plugin_service" - def __init__(self): + def __init__(self) -> None: self.plugins: dict[str, BasePlugin] = {} self.plugin_dir = Path(__file__).parent self.plugins_base_module = "langflow.services.plugins" self.load_plugins() - def load_plugins(self): + def load_plugins(self) -> None: base_files = ["base.py", "service.py", "factory.py", "__init__.py"] for module in self.plugin_dir.iterdir(): if module.suffix == ".py" and module.name not in base_files: @@ -38,7 +38,7 @@ def load_plugins(self): except Exception: # noqa: BLE001 logger.exception(f"Error loading plugin {plugin_name}") - def register_plugin(self, plugin_name, plugin_instance): + def register_plugin(self, plugin_name, plugin_instance) -> None: self.plugins[plugin_name] = plugin_instance plugin_instance.initialize() @@ -50,7 +50,7 @@ def get(self, plugin_name): return plugin.get() return None - async def teardown(self): + async def teardown(self) -> None: for plugin in self.plugins.values(): plugin.teardown() diff --git a/src/backend/base/langflow/services/session/factory.py b/src/backend/base/langflow/services/session/factory.py index d55bd5b46850..806067529fe6 100644 --- a/src/backend/base/langflow/services/session/factory.py +++ b/src/backend/base/langflow/services/session/factory.py @@ -8,7 +8,7 @@ class SessionServiceFactory(ServiceFactory): - def __init__(self): + def __init__(self) -> None: super().__init__(SessionService) def create(self, cache_service: "CacheService"): diff --git a/src/backend/base/langflow/services/session/service.py b/src/backend/base/langflow/services/session/service.py index 0997b7912392..abd4e630a78f 100644 --- a/src/backend/base/langflow/services/session/service.py +++ b/src/backend/base/langflow/services/session/service.py @@ -11,7 +11,7 @@ class SessionService(Service): name = "session_service" - def __init__(self, cache_service): + def __init__(self, cache_service) -> None: self.cache_service: CacheService = cache_service async def load_session(self, key, flow_id: str, data_graph: dict | None = None): @@ -35,7 +35,7 @@ async def load_session(self, key, flow_id: str, data_graph: dict | None = None): return graph, artifacts - def build_key(self, session_id, data_graph): + def build_key(self, session_id, data_graph) -> str: json_hash = compute_dict_hash(data_graph) return f"{session_id}{':' if session_id else ''}{json_hash}" @@ -46,13 +46,13 @@ def generate_key(self, session_id, data_graph): session_id = session_id_generator() return self.build_key(session_id, data_graph=data_graph) - async def update_session(self, session_id, value): + async def update_session(self, session_id, value) -> None: result = self.cache_service.set(session_id, value) # if it is a coroutine, await it if isinstance(result, Coroutine): await result - async def clear_session(self, session_id): + async def clear_session(self, session_id) -> None: result = self.cache_service.delete(session_id) # if it is a coroutine, await it if isinstance(result, Coroutine): diff --git a/src/backend/base/langflow/services/settings/auth.py b/src/backend/base/langflow/services/settings/auth.py index b5f36632425f..045dd155d989 100644 --- a/src/backend/base/langflow/services/settings/auth.py +++ b/src/backend/base/langflow/services/settings/auth.py @@ -57,7 +57,7 @@ class Config: extra = "ignore" env_prefix = "LANGFLOW_" - def reset_credentials(self): + def reset_credentials(self) -> None: self.SUPERUSER = DEFAULT_SUPERUSER self.SUPERUSER_PASSWORD = DEFAULT_SUPERUSER_PASSWORD diff --git a/src/backend/base/langflow/services/settings/base.py b/src/backend/base/langflow/services/settings/base.py index c1d1d5d0d405..cdbdccde317d 100644 --- a/src/backend/base/langflow/services/settings/base.py +++ b/src/backend/base/langflow/services/settings/base.py @@ -326,12 +326,12 @@ def set_components_path(cls, value): model_config = SettingsConfigDict(validate_assignment=True, extra="ignore", env_prefix="LANGFLOW_") - def update_from_yaml(self, file_path: str, *, dev: bool = False): + def update_from_yaml(self, file_path: str, *, dev: bool = False) -> None: new_settings = load_settings_from_yaml(file_path) self.components_path = new_settings.components_path or [] self.dev = dev - def update_settings(self, **kwargs): + def update_settings(self, **kwargs) -> None: logger.debug("Updating settings") for key, value in kwargs.items(): # value may contain sensitive information, so we don't want to log it @@ -374,7 +374,7 @@ def settings_customise_sources( # type: ignore[misc] return (MyCustomSource(settings_cls),) -def save_settings_to_yaml(settings: Settings, file_path: str): +def save_settings_to_yaml(settings: Settings, file_path: str) -> None: with Path(file_path).open("w", encoding="utf-8") as f: settings_dict = settings.model_dump() yaml.dump(settings_dict, f) diff --git a/src/backend/base/langflow/services/settings/factory.py b/src/backend/base/langflow/services/settings/factory.py index 008982f2b1e9..30f3ca9c3b02 100644 --- a/src/backend/base/langflow/services/settings/factory.py +++ b/src/backend/base/langflow/services/settings/factory.py @@ -10,7 +10,7 @@ def __new__(cls): cls._instance = super().__new__(cls) return cls._instance - def __init__(self): + def __init__(self) -> None: super().__init__(SettingsService) def create(self): diff --git a/src/backend/base/langflow/services/settings/utils.py b/src/backend/base/langflow/services/settings/utils.py index 0ce45f9e141d..b280444df70b 100644 --- a/src/backend/base/langflow/services/settings/utils.py +++ b/src/backend/base/langflow/services/settings/utils.py @@ -4,7 +4,7 @@ from loguru import logger -def set_secure_permissions(file_path: Path): +def set_secure_permissions(file_path: Path) -> None: if platform.system() in {"Linux", "Darwin"}: # Unix/Linux/Mac file_path.chmod(0o600) elif platform.system() == "Windows": diff --git a/src/backend/base/langflow/services/shared_component_cache/factory.py b/src/backend/base/langflow/services/shared_component_cache/factory.py index ddb51a0495cb..c9c464967352 100644 --- a/src/backend/base/langflow/services/shared_component_cache/factory.py +++ b/src/backend/base/langflow/services/shared_component_cache/factory.py @@ -8,7 +8,7 @@ class SharedComponentCacheServiceFactory(ServiceFactory): - def __init__(self): + def __init__(self) -> None: super().__init__(SharedComponentCacheService) def create(self, settings_service: "SettingsService"): diff --git a/src/backend/base/langflow/services/socket/factory.py b/src/backend/base/langflow/services/socket/factory.py index 3ea6bb0ba0cd..68fe8337c8f5 100644 --- a/src/backend/base/langflow/services/socket/factory.py +++ b/src/backend/base/langflow/services/socket/factory.py @@ -8,7 +8,7 @@ class SocketIOFactory(ServiceFactory): - def __init__(self): + def __init__(self) -> None: super().__init__( service_class=SocketIOService, ) diff --git a/src/backend/base/langflow/services/socket/service.py b/src/backend/base/langflow/services/socket/service.py index d2385cf0a76b..d4ec60579fb9 100644 --- a/src/backend/base/langflow/services/socket/service.py +++ b/src/backend/base/langflow/services/socket/service.py @@ -17,7 +17,7 @@ class SocketIOService(Service): def __init__(self, cache_service: "CacheService"): self.cache_service = cache_service - def init(self, sio: socketio.AsyncServer): + def init(self, sio: socketio.AsyncServer) -> None: # Registering event handlers self.sio = sio if self.sio: @@ -28,32 +28,32 @@ def init(self, sio: socketio.AsyncServer): self.sio.on("build_vertex")(self.on_build_vertex) self.sessions = {} # type: dict[str, dict] - async def emit_error(self, sid, error): + async def emit_error(self, sid, error) -> None: await self.sio.emit("error", to=sid, data=error) - async def connect(self, sid, environ): + async def connect(self, sid, environ) -> None: logger.info(f"Socket connected: {sid}") self.sessions[sid] = environ - async def disconnect(self, sid): + async def disconnect(self, sid) -> None: logger.info(f"Socket disconnected: {sid}") self.sessions.pop(sid, None) - async def message(self, sid, data=None): + async def message(self, sid, data=None) -> None: # Logic for handling messages await self.emit_message(to=sid, data=data or {"foo": "bar", "baz": [1, 2, 3]}) - async def emit_message(self, to, data): + async def emit_message(self, to, data) -> None: # Abstracting sio.emit await self.sio.emit("message", to=to, data=data) - async def emit_token(self, to, data): + async def emit_token(self, to, data) -> None: await self.sio.emit("token", to=to, data=data) - async def on_get_vertices(self, sid, flow_id): + async def on_get_vertices(self, sid, flow_id) -> None: await get_vertices(self.sio, sid, flow_id, get_chat_service()) - async def on_build_vertex(self, sid, flow_id, vertex_id): + async def on_build_vertex(self, sid, flow_id, vertex_id) -> None: await build_vertex( sio=self.sio, sid=sid, diff --git a/src/backend/base/langflow/services/socket/utils.py b/src/backend/base/langflow/services/socket/utils.py index ae96bc06013b..b79d48993449 100644 --- a/src/backend/base/langflow/services/socket/utils.py +++ b/src/backend/base/langflow/services/socket/utils.py @@ -14,16 +14,16 @@ from langflow.services.deps import get_session -def set_socketio_server(socketio_server): +def set_socketio_server(socketio_server) -> None: from langflow.services.deps import get_socket_service socket_service = get_socket_service() socket_service.init(socketio_server) -async def get_vertices(sio, sid, flow_id, chat_service): +async def get_vertices(sio, sid, flow_id, chat_service) -> None: try: - session = get_session() + session = next(get_session()) flow: Flow = session.exec(select(Flow).where(Flow.id == flow_id)).first() if not flow or not flow.data: await sio.emit("error", data="Invalid flow ID", to=sid) @@ -31,7 +31,7 @@ async def get_vertices(sio, sid, flow_id, chat_service): graph = Graph.from_payload(flow.data) chat_service.set_cache(flow_id, graph) - vertices = graph.layered_topological_sort() + vertices = graph.layered_topological_sort(graph.vertices) # Emit the vertices to the client await sio.emit("vertices_order", data=vertices, to=sid) @@ -48,7 +48,7 @@ async def build_vertex( vertex_id: str, get_cache: Callable, set_cache: Callable, -): +) -> None: try: cache = get_cache(flow_id) graph = cache.get("result") diff --git a/src/backend/base/langflow/services/state/factory.py b/src/backend/base/langflow/services/state/factory.py index e6c5ee740e08..350d7bdcdaf9 100644 --- a/src/backend/base/langflow/services/state/factory.py +++ b/src/backend/base/langflow/services/state/factory.py @@ -4,7 +4,7 @@ class StateServiceFactory(ServiceFactory): - def __init__(self): + def __init__(self) -> None: super().__init__(InMemoryStateService) def create(self, settings_service: SettingsService): diff --git a/src/backend/base/langflow/services/state/service.py b/src/backend/base/langflow/services/state/service.py index cb61f29e822a..100b5442da7b 100644 --- a/src/backend/base/langflow/services/state/service.py +++ b/src/backend/base/langflow/services/state/service.py @@ -11,19 +11,19 @@ class StateService(Service): name = "state_service" - def append_state(self, key, new_state, run_id: str): + def append_state(self, key, new_state, run_id: str) -> None: raise NotImplementedError - def update_state(self, key, new_state, run_id: str): + def update_state(self, key, new_state, run_id: str) -> None: raise NotImplementedError def get_state(self, key, run_id: str): raise NotImplementedError - def subscribe(self, key, observer: Callable): + def subscribe(self, key, observer: Callable) -> None: raise NotImplementedError - def notify_observers(self, key, new_state): + def notify_observers(self, key, new_state) -> None: raise NotImplementedError @@ -34,7 +34,7 @@ def __init__(self, settings_service: SettingsService): self.observers: dict = defaultdict(list) self.lock = Lock() - def append_state(self, key, new_state, run_id: str): + def append_state(self, key, new_state, run_id: str) -> None: with self.lock: if run_id not in self.states: self.states[run_id] = {} @@ -45,7 +45,7 @@ def append_state(self, key, new_state, run_id: str): self.states[run_id][key].append(new_state) self.notify_append_observers(key, new_state) - def update_state(self, key, new_state, run_id: str): + def update_state(self, key, new_state, run_id: str) -> None: with self.lock: if run_id not in self.states: self.states[run_id] = {} @@ -56,16 +56,16 @@ def get_state(self, key, run_id: str): with self.lock: return self.states.get(run_id, {}).get(key, "") - def subscribe(self, key, observer: Callable): + def subscribe(self, key, observer: Callable) -> None: with self.lock: if observer not in self.observers[key]: self.observers[key].append(observer) - def notify_observers(self, key, new_state): + def notify_observers(self, key, new_state) -> None: for callback in self.observers[key]: callback(key, new_state, append=False) - def notify_append_observers(self, key, new_state): + def notify_append_observers(self, key, new_state) -> None: for callback in self.observers[key]: try: callback(key, new_state, append=True) diff --git a/src/backend/base/langflow/services/storage/factory.py b/src/backend/base/langflow/services/storage/factory.py index 335772657ccd..f42a84ece07c 100644 --- a/src/backend/base/langflow/services/storage/factory.py +++ b/src/backend/base/langflow/services/storage/factory.py @@ -7,7 +7,7 @@ class StorageServiceFactory(ServiceFactory): - def __init__(self): + def __init__(self) -> None: super().__init__( StorageService, ) diff --git a/src/backend/base/langflow/services/storage/local.py b/src/backend/base/langflow/services/storage/local.py index 36304aec9bfe..4748a26a2c83 100644 --- a/src/backend/base/langflow/services/storage/local.py +++ b/src/backend/base/langflow/services/storage/local.py @@ -9,7 +9,7 @@ class LocalStorageService(StorageService): """A service class for handling local storage operations without aiofiles.""" - def __init__(self, session_service, settings_service): + def __init__(self, session_service, settings_service) -> None: """Initialize the local storage service with session and settings services.""" super().__init__(session_service, settings_service) self.data_dir = Path(settings_service.settings.config_dir) @@ -19,7 +19,7 @@ def build_full_path(self, flow_id: str, file_name: str) -> str: """Build the full path of a file in the local storage.""" return str(self.data_dir / flow_id / file_name) - async def save_file(self, flow_id: str, file_name: str, data: bytes): + async def save_file(self, flow_id: str, file_name: str, data: bytes) -> None: """Save a file in the local storage. :param flow_id: The identifier for the flow. @@ -81,7 +81,7 @@ async def list_files(self, flow_id: str): logger.info(f"Listed {len(files)} files in flow {flow_id}.") return files - async def delete_file(self, flow_id: str, file_name: str): + async def delete_file(self, flow_id: str, file_name: str) -> None: """Delete a file from the local storage. :param flow_id: The identifier for the flow. @@ -94,6 +94,6 @@ async def delete_file(self, flow_id: str, file_name: str): else: logger.warning(f"Attempted to delete non-existent file {file_name} in flow {flow_id}.") - async def teardown(self): + async def teardown(self) -> None: """Perform any cleanup operations when the service is being torn down.""" # No specific teardown actions required for local diff --git a/src/backend/base/langflow/services/storage/s3.py b/src/backend/base/langflow/services/storage/s3.py index f996c1035b54..46ce643c3f1f 100644 --- a/src/backend/base/langflow/services/storage/s3.py +++ b/src/backend/base/langflow/services/storage/s3.py @@ -8,14 +8,14 @@ class S3StorageService(StorageService): """A service class for handling operations with AWS S3 storage.""" - async def __init__(self, session_service, settings_service): + def __init__(self, session_service, settings_service) -> None: """Initialize the S3 storage service with session and settings services.""" super().__init__(session_service, settings_service) self.bucket = "langflow" self.s3_client = boto3.client("s3") self.set_ready() - async def save_file(self, folder: str, file_name: str, data): + async def save_file(self, folder: str, file_name: str, data) -> None: """Save a file to the S3 bucket. :param folder: The folder in the bucket to save the file. @@ -66,7 +66,7 @@ async def list_files(self, folder: str): logger.info(f"{len(files)} files listed in folder {folder}.") return files - async def delete_file(self, folder: str, file_name: str): + async def delete_file(self, folder: str, file_name: str) -> None: """Delete a file from the S3 bucket. :param folder: The folder in the bucket where the file is stored. @@ -80,6 +80,6 @@ async def delete_file(self, folder: str, file_name: str): logger.exception(f"Error deleting file {file_name} from folder {folder}") raise - async def teardown(self): + async def teardown(self) -> None: """Perform any cleanup operations when the service is being torn down.""" # No specific teardown actions required for S3 storage at the moment. diff --git a/src/backend/base/langflow/services/storage/service.py b/src/backend/base/langflow/services/storage/service.py index 62eff6ab0ee3..e139b73c4753 100644 --- a/src/backend/base/langflow/services/storage/service.py +++ b/src/backend/base/langflow/services/storage/service.py @@ -21,7 +21,7 @@ def __init__(self, session_service: SessionService, settings_service: SettingsSe def build_full_path(self, flow_id: str, file_name: str) -> str: raise NotImplementedError - def set_ready(self): + def set_ready(self) -> None: self.ready = True @abstractmethod @@ -37,8 +37,8 @@ async def list_files(self, flow_id: str) -> list[str]: raise NotImplementedError @abstractmethod - async def delete_file(self, flow_id: str, file_name: str) -> bool: + async def delete_file(self, flow_id: str, file_name: str) -> None: raise NotImplementedError - async def teardown(self): + async def teardown(self) -> None: raise NotImplementedError diff --git a/src/backend/base/langflow/services/store/exceptions.py b/src/backend/base/langflow/services/store/exceptions.py index 60ee08b49b66..1c00c5a50cc1 100644 --- a/src/backend/base/langflow/services/store/exceptions.py +++ b/src/backend/base/langflow/services/store/exceptions.py @@ -1,25 +1,25 @@ class CustomError(Exception): - def __init__(self, detail, status_code): + def __init__(self, detail: str, status_code: int): super().__init__(detail) self.status_code = status_code # Define custom exceptions with status codes class UnauthorizedError(CustomError): - def __init__(self, detail="Unauthorized access"): + def __init__(self, detail: str = "Unauthorized access"): super().__init__(detail, 401) class ForbiddenError(CustomError): - def __init__(self, detail="Forbidden"): + def __init__(self, detail: str = "Forbidden"): super().__init__(detail, 403) class APIKeyError(CustomError): - def __init__(self, detail="API key error"): + def __init__(self, detail: str = "API key error"): super().__init__(detail, 400) # ! Should be 401 class FilterError(CustomError): - def __init__(self, detail="Filter error"): + def __init__(self, detail: str = "Filter error"): super().__init__(detail, 400) diff --git a/src/backend/base/langflow/services/store/factory.py b/src/backend/base/langflow/services/store/factory.py index 3b668a226cba..0a4f18c4a2d5 100644 --- a/src/backend/base/langflow/services/store/factory.py +++ b/src/backend/base/langflow/services/store/factory.py @@ -10,7 +10,7 @@ class StoreServiceFactory(ServiceFactory): - def __init__(self): + def __init__(self) -> None: super().__init__(StoreService) def create(self, settings_service: SettingsService): diff --git a/src/backend/base/langflow/services/task/backends/anyio.py b/src/backend/base/langflow/services/task/backends/anyio.py index 15ba520a9dfb..8f167283fb12 100644 --- a/src/backend/base/langflow/services/task/backends/anyio.py +++ b/src/backend/base/langflow/services/task/backends/anyio.py @@ -9,11 +9,11 @@ class AnyIOTaskResult: - def __init__(self, scope): + def __init__(self, scope) -> None: self._scope = scope self._status = "PENDING" self._result = None - self._exception = None + self._exception: Exception | None = None @property def status(self) -> str: @@ -34,7 +34,7 @@ def result(self) -> Any: def ready(self) -> bool: return self._status == "DONE" - async def run(self, func, *args, **kwargs): + async def run(self, func, *args, **kwargs) -> None: try: self._result = await func(*args, **kwargs) except Exception as e: # noqa: BLE001 @@ -47,8 +47,8 @@ async def run(self, func, *args, **kwargs): class AnyIOBackend(TaskBackend): name = "anyio" - def __init__(self): - self.tasks = {} + def __init__(self) -> None: + self.tasks: dict[str, AnyIOTaskResult] = {} async def launch_task( self, task_func: Callable[..., Any], *args: Any, **kwargs: Any diff --git a/src/backend/base/langflow/services/task/backends/celery.py b/src/backend/base/langflow/services/task/backends/celery.py index e5dbb3e50684..d26725e7676e 100644 --- a/src/backend/base/langflow/services/task/backends/celery.py +++ b/src/backend/base/langflow/services/task/backends/celery.py @@ -13,7 +13,7 @@ class CeleryBackend(TaskBackend): name = "celery" - def __init__(self): + def __init__(self) -> None: self.celery_app = celery_app def launch_task(self, task_func: Callable[..., Any], *args: Any, **kwargs: Any) -> tuple[str, AsyncResult]: diff --git a/src/backend/base/langflow/services/task/factory.py b/src/backend/base/langflow/services/task/factory.py index 937f390ae079..c030776de5ea 100644 --- a/src/backend/base/langflow/services/task/factory.py +++ b/src/backend/base/langflow/services/task/factory.py @@ -3,7 +3,7 @@ class TaskServiceFactory(ServiceFactory): - def __init__(self): + def __init__(self) -> None: super().__init__(TaskService) def create(self): diff --git a/src/backend/base/langflow/services/telemetry/factory.py b/src/backend/base/langflow/services/telemetry/factory.py index 6368377a4c33..1fb087de7c39 100644 --- a/src/backend/base/langflow/services/telemetry/factory.py +++ b/src/backend/base/langflow/services/telemetry/factory.py @@ -10,7 +10,7 @@ class TelemetryServiceFactory(ServiceFactory): - def __init__(self): + def __init__(self) -> None: super().__init__(TelemetryService) def create(self, settings_service: SettingsService): diff --git a/src/backend/base/langflow/services/telemetry/opentelemetry.py b/src/backend/base/langflow/services/telemetry/opentelemetry.py index 82d5406d9324..4fe30d6756ba 100644 --- a/src/backend/base/langflow/services/telemetry/opentelemetry.py +++ b/src/backend/base/langflow/services/telemetry/opentelemetry.py @@ -55,7 +55,7 @@ def _callback(self, _options: CallbackOptions): # return [Observation(self._value)] - def set_value(self, value: float, labels: Mapping[str, str]): + def set_value(self, value: float, labels: Mapping[str, str]) -> None: self._values[tuple(sorted(labels.items()))] = value @@ -76,7 +76,7 @@ def __init__( self.mandatory_labels = [label for label, required in labels.items() if required] self.allowed_labels = list(labels.keys()) - def validate_labels(self, labels: Mapping[str, str]): + def validate_labels(self, labels: Mapping[str, str]) -> None: """Validate if the labels provided are valid.""" if labels is None or len(labels) == 0: msg = "Labels must be provided for the metric" @@ -87,7 +87,7 @@ def validate_labels(self, labels: Mapping[str, str]): msg = f"Missing required labels: {missing_labels}" raise ValueError(msg) - def __repr__(self): + def __repr__(self) -> str: return f"Metric(name='{self.name}', description='{self.description}', type={self.type}, unit='{self.unit}')" @@ -111,14 +111,16 @@ class OpenTelemetry(metaclass=ThreadSafeSingletonMetaUsingWeakref): _metrics: dict[str, Counter | ObservableGaugeWrapper | Histogram | UpDownCounter] = {} _meter_provider: MeterProvider | None = None - def _add_metric(self, name: str, description: str, unit: str, metric_type: MetricType, labels: dict[str, bool]): + def _add_metric( + self, name: str, description: str, unit: str, metric_type: MetricType, labels: dict[str, bool] + ) -> None: metric = Metric(name=name, description=description, metric_type=metric_type, unit=unit, labels=labels) self._metrics_registry[name] = metric if labels is None or len(labels) == 0: msg = "Labels must be provided for the metric upon registration" raise ValueError(msg) - def _register_metric(self): + def _register_metric(self) -> None: """Define any custom metrics here. A thread safe singleton class to manage metrics. @@ -196,14 +198,14 @@ def _create_metric(self, metric): msg = f"Unknown metric type: {metric.type}" raise ValueError(msg) - def validate_labels(self, metric_name: str, labels: Mapping[str, str]): + def validate_labels(self, metric_name: str, labels: Mapping[str, str]) -> None: reg = self._metrics_registry.get(metric_name) if reg is None: msg = f"Metric '{metric_name}' is not registered" raise ValueError(msg) reg.validate_labels(labels) - def increment_counter(self, metric_name: str, labels: Mapping[str, str], value: float = 1.0): + def increment_counter(self, metric_name: str, labels: Mapping[str, str], value: float = 1.0) -> None: self.validate_labels(metric_name, labels) counter = self._metrics.get(metric_name) if isinstance(counter, Counter): @@ -212,7 +214,7 @@ def increment_counter(self, metric_name: str, labels: Mapping[str, str], value: msg = f"Metric '{metric_name}' is not a counter" raise TypeError(msg) - def up_down_counter(self, metric_name: str, value: float, labels: Mapping[str, str]): + def up_down_counter(self, metric_name: str, value: float, labels: Mapping[str, str]) -> None: self.validate_labels(metric_name, labels) up_down_counter = self._metrics.get(metric_name) if isinstance(up_down_counter, UpDownCounter): @@ -221,7 +223,7 @@ def up_down_counter(self, metric_name: str, value: float, labels: Mapping[str, s msg = f"Metric '{metric_name}' is not an up down counter" raise TypeError(msg) - def update_gauge(self, metric_name: str, value: float, labels: Mapping[str, str]): + def update_gauge(self, metric_name: str, value: float, labels: Mapping[str, str]) -> None: self.validate_labels(metric_name, labels) gauge = self._metrics.get(metric_name) if isinstance(gauge, ObservableGaugeWrapper): @@ -230,7 +232,7 @@ def update_gauge(self, metric_name: str, value: float, labels: Mapping[str, str] msg = f"Metric '{metric_name}' is not a gauge" raise TypeError(msg) - def observe_histogram(self, metric_name: str, value: float, labels: Mapping[str, str]): + def observe_histogram(self, metric_name: str, value: float, labels: Mapping[str, str]) -> None: self.validate_labels(metric_name, labels) histogram = self._metrics.get(metric_name) if isinstance(histogram, Histogram): diff --git a/src/backend/base/langflow/services/telemetry/service.py b/src/backend/base/langflow/services/telemetry/service.py index 65fc851f4a1e..cb396981f5b9 100644 --- a/src/backend/base/langflow/services/telemetry/service.py +++ b/src/backend/base/langflow/services/telemetry/service.py @@ -46,7 +46,7 @@ def __init__(self, settings_service: SettingsService): os.getenv("DO_NOT_TRACK", "False").lower() == "true" or settings_service.settings.do_not_track ) - async def telemetry_worker(self): + async def telemetry_worker(self) -> None: while self.running: func, payload, path = await self.telemetry_queue.get() try: @@ -56,7 +56,7 @@ async def telemetry_worker(self): finally: self.telemetry_queue.task_done() - async def send_telemetry_data(self, payload: BaseModel, path: str | None = None): + async def send_telemetry_data(self, payload: BaseModel, path: str | None = None) -> None: if self.do_not_track: logger.debug("Telemetry tracking is disabled.") return @@ -78,19 +78,19 @@ async def send_telemetry_data(self, payload: BaseModel, path: str | None = None) except Exception: # noqa: BLE001 logger.exception("Unexpected error occurred") - async def log_package_run(self, payload: RunPayload): + async def log_package_run(self, payload: RunPayload) -> None: await self._queue_event((self.send_telemetry_data, payload, "run")) - async def log_package_shutdown(self): + async def log_package_shutdown(self) -> None: payload = ShutdownPayload(time_running=(datetime.now(timezone.utc) - self._start_time).seconds) await self._queue_event(payload) - async def _queue_event(self, payload): + async def _queue_event(self, payload) -> None: if self.do_not_track or self._stopping: return await self.telemetry_queue.put(payload) - async def log_package_version(self): + async def log_package_version(self) -> None: python_version = ".".join(platform.python_version().split(".")[:2]) version_info = get_version_info() architecture = platform.architecture()[0] @@ -106,13 +106,13 @@ async def log_package_version(self): ) await self._queue_event((self.send_telemetry_data, payload, None)) - async def log_package_playground(self, payload: PlaygroundPayload): + async def log_package_playground(self, payload: PlaygroundPayload) -> None: await self._queue_event((self.send_telemetry_data, payload, "playground")) - async def log_package_component(self, payload: ComponentPayload): + async def log_package_component(self, payload: ComponentPayload) -> None: await self._queue_event((self.send_telemetry_data, payload, "component")) - async def start(self): + async def start(self) -> None: if self.running or self.do_not_track: return try: @@ -123,7 +123,7 @@ async def start(self): except Exception: # noqa: BLE001 logger.exception("Error starting telemetry service") - async def flush(self): + async def flush(self) -> None: if self.do_not_track: return try: @@ -131,7 +131,7 @@ async def flush(self): except Exception: # noqa: BLE001 logger.exception("Error flushing logs") - async def stop(self): + async def stop(self) -> None: if self.do_not_track or self._stopping: return try: @@ -147,5 +147,5 @@ async def stop(self): except Exception: # noqa: BLE001 logger.exception("Error stopping tracing service") - async def teardown(self): + async def teardown(self) -> None: await self.stop() diff --git a/src/backend/base/langflow/services/tracing/base.py b/src/backend/base/langflow/services/tracing/base.py index 98763526f285..9b51c6e38d91 100644 --- a/src/backend/base/langflow/services/tracing/base.py +++ b/src/backend/base/langflow/services/tracing/base.py @@ -14,6 +14,8 @@ class BaseTracer(ABC): + trace_id: UUID + @abstractmethod def __init__(self, trace_name: str, trace_type: str, project_name: str, trace_id: UUID): raise NotImplementedError @@ -32,7 +34,7 @@ def add_trace( inputs: dict[str, Any], metadata: dict[str, Any] | None = None, vertex: Vertex | None = None, - ): + ) -> None: raise NotImplementedError @abstractmethod @@ -43,7 +45,7 @@ def end_trace( outputs: dict[str, Any] | None = None, error: Exception | None = None, logs: Sequence[Log | dict] = (), - ): + ) -> None: raise NotImplementedError @abstractmethod @@ -53,7 +55,7 @@ def end( outputs: dict[str, Any], error: Exception | None = None, metadata: dict[str, Any] | None = None, - ): + ) -> None: raise NotImplementedError @abstractmethod diff --git a/src/backend/base/langflow/services/tracing/factory.py b/src/backend/base/langflow/services/tracing/factory.py index 5b8fe5508bdb..f1971622e7c8 100644 --- a/src/backend/base/langflow/services/tracing/factory.py +++ b/src/backend/base/langflow/services/tracing/factory.py @@ -10,7 +10,7 @@ class TracingServiceFactory(ServiceFactory): - def __init__(self): + def __init__(self) -> None: super().__init__(TracingService) def create(self, settings_service: SettingsService): diff --git a/src/backend/base/langflow/services/tracing/langfuse.py b/src/backend/base/langflow/services/tracing/langfuse.py index 5f66f0a12d64..0bcb192d5b9f 100644 --- a/src/backend/base/langflow/services/tracing/langfuse.py +++ b/src/backend/base/langflow/services/tracing/langfuse.py @@ -73,7 +73,7 @@ def add_trace( inputs: dict[str, Any], metadata: dict[str, Any] | None = None, vertex: Vertex | None = None, - ): + ) -> None: start_time = datetime.now(tz=timezone.utc) if not self._ready: return @@ -103,7 +103,7 @@ def end_trace( outputs: dict[str, Any] | None = None, error: Exception | None = None, logs: Sequence[Log | dict] = (), - ): + ) -> None: end_time = datetime.now(tz=timezone.utc) if not self._ready: return @@ -124,7 +124,7 @@ def end( outputs: dict[str, Any], error: Exception | None = None, metadata: dict[str, Any] | None = None, - ): + ) -> None: if not self._ready: return diff --git a/src/backend/base/langflow/services/tracing/langsmith.py b/src/backend/base/langflow/services/tracing/langsmith.py index fea4ea391ad8..4c57124490ee 100644 --- a/src/backend/base/langflow/services/tracing/langsmith.py +++ b/src/backend/base/langflow/services/tracing/langsmith.py @@ -50,7 +50,7 @@ def __init__(self, trace_name: str, trace_type: str, project_name: str, trace_id def ready(self): return self._ready - def setup_langsmith(self): + def setup_langsmith(self) -> bool: if os.getenv("LANGCHAIN_API_KEY") is None: return False try: @@ -72,7 +72,7 @@ def add_trace( inputs: dict[str, Any], metadata: dict[str, Any] | None = None, vertex: Vertex | None = None, - ): + ) -> None: if not self._ready: return processed_inputs = {} @@ -125,7 +125,7 @@ def end_trace( outputs: dict[str, Any] | None = None, error: Exception | None = None, logs: Sequence[Log | dict] = (), - ): + ) -> None: if not self._ready: return child = self._children[trace_name] @@ -158,7 +158,7 @@ def end( outputs: dict[str, Any], error: Exception | None = None, metadata: dict[str, Any] | None = None, - ): + ) -> None: if not self._ready: return self._run_tree.add_metadata({"inputs": inputs}) diff --git a/src/backend/base/langflow/services/tracing/langwatch.py b/src/backend/base/langflow/services/tracing/langwatch.py index 163181750099..247da54b7e3d 100644 --- a/src/backend/base/langflow/services/tracing/langwatch.py +++ b/src/backend/base/langflow/services/tracing/langwatch.py @@ -55,7 +55,7 @@ def __init__(self, trace_name: str, trace_type: str, project_name: str, trace_id def ready(self): return self._ready - def setup_langwatch(self): + def setup_langwatch(self) -> bool: try: import langwatch @@ -74,7 +74,7 @@ def add_trace( inputs: dict[str, Any], metadata: dict[str, Any] | None = None, vertex: Vertex | None = None, - ): + ) -> None: if not self._ready: return # If user is not using session_id, then it becomes the same as flow_id, but @@ -109,7 +109,7 @@ def end_trace( outputs: dict[str, Any] | None = None, error: Exception | None = None, logs: Sequence[Log | dict] = (), - ): + ) -> None: if not self._ready: return if self.spans.get(trace_id): @@ -121,7 +121,7 @@ def end( outputs: dict[str, Any], error: Exception | None = None, metadata: dict[str, Any] | None = None, - ): + ) -> None: if not self._ready: return self.trace.root_span.end( diff --git a/src/backend/base/langflow/services/tracing/service.py b/src/backend/base/langflow/services/tracing/service.py index 7fca5e36cd12..1db7c75b5ad0 100644 --- a/src/backend/base/langflow/services/tracing/service.py +++ b/src/backend/base/langflow/services/tracing/service.py @@ -51,15 +51,15 @@ def __init__(self, settings_service: SettingsService): self.outputs_metadata: dict[str, dict] = defaultdict(dict) self.run_name: str | None = None self.run_id: UUID | None = None - self.project_name = None + self.project_name: str | None = None self._tracers: dict[str, BaseTracer] = {} self._logs: dict[str, list[Log | dict[Any, Any]]] = defaultdict(list) self.logs_queue: asyncio.Queue = asyncio.Queue() self.running = False - self.worker_task = None + self.worker_task: asyncio.Task | None = None self.end_trace_tasks: set[asyncio.Task] = set() - async def log_worker(self): + async def log_worker(self) -> None: while self.running or not self.logs_queue.empty(): log_func, args = await self.logs_queue.get() try: @@ -69,7 +69,7 @@ async def log_worker(self): finally: self.logs_queue.task_done() - async def start(self): + async def start(self) -> None: if self.running: return try: @@ -78,13 +78,13 @@ async def start(self): except Exception: # noqa: BLE001 logger.exception("Error starting tracing service") - async def flush(self): + async def flush(self) -> None: try: await self.logs_queue.join() except Exception: # noqa: BLE001 logger.exception("Error flushing logs") - async def stop(self): + async def stop(self) -> None: try: self.running = False await self.flush() @@ -98,13 +98,13 @@ async def stop(self): except Exception: # noqa: BLE001 logger.exception("Error stopping tracing service") - def _reset_io(self): + def _reset_io(self) -> None: self.inputs = defaultdict(dict) self.inputs_metadata = defaultdict(dict) self.outputs = defaultdict(dict) self.outputs_metadata = defaultdict(dict) - async def initialize_tracers(self): + async def initialize_tracers(self) -> None: try: await self.start() self._initialize_langsmith_tracer() @@ -113,7 +113,7 @@ async def initialize_tracers(self): except Exception: # noqa: BLE001 logger.opt(exception=True).debug("Error initializing tracers") - def _initialize_langsmith_tracer(self): + def _initialize_langsmith_tracer(self) -> None: project_name = os.getenv("LANGCHAIN_PROJECT", "Langflow") self.project_name = project_name langsmith_tracer = _get_langsmith_tracer() @@ -124,7 +124,7 @@ def _initialize_langsmith_tracer(self): trace_id=self.run_id, ) - def _initialize_langwatch_tracer(self): + def _initialize_langwatch_tracer(self) -> None: if "langwatch" not in self._tracers or self._tracers["langwatch"].trace_id != self.run_id: langwatch_tracer = _get_langwatch_tracer() self._tracers["langwatch"] = langwatch_tracer( @@ -134,7 +134,7 @@ def _initialize_langwatch_tracer(self): trace_id=self.run_id, ) - def _initialize_langfuse_tracer(self): + def _initialize_langfuse_tracer(self) -> None: self.project_name = os.getenv("LANGCHAIN_PROJECT", "Langflow") langfuse_tracer = _get_langfuse_tracer() self._tracers["langfuse"] = langfuse_tracer( @@ -144,10 +144,10 @@ def _initialize_langfuse_tracer(self): trace_id=self.run_id, ) - def set_run_name(self, name: str): + def set_run_name(self, name: str) -> None: self.run_name = name - def set_run_id(self, run_id: UUID): + def set_run_id(self, run_id: UUID) -> None: self.run_id = run_id def _start_traces( @@ -158,7 +158,7 @@ def _start_traces( inputs: dict[str, Any], metadata: dict[str, Any] | None = None, vertex: Vertex | None = None, - ): + ) -> None: inputs = self._cleanup_inputs(inputs) self.inputs[trace_name] = inputs self.inputs_metadata[trace_name] = metadata or {} @@ -170,7 +170,7 @@ def _start_traces( except Exception: # noqa: BLE001 logger.exception(f"Error starting trace {trace_name}") - def _end_traces(self, trace_id: str, trace_name: str, error: Exception | None = None): + def _end_traces(self, trace_id: str, trace_name: str, error: Exception | None = None) -> None: for tracer in self._tracers.values(): if tracer.ready: try: @@ -184,7 +184,7 @@ def _end_traces(self, trace_id: str, trace_name: str, error: Exception | None = except Exception: # noqa: BLE001 logger.exception(f"Error ending trace {trace_name}") - def _end_all_traces(self, outputs: dict, error: Exception | None = None): + def _end_all_traces(self, outputs: dict, error: Exception | None = None) -> None: for tracer in self._tracers.values(): if tracer.ready: try: @@ -192,12 +192,12 @@ def _end_all_traces(self, outputs: dict, error: Exception | None = None): except Exception: # noqa: BLE001 logger.exception("Error ending all traces") - async def end(self, outputs: dict, error: Exception | None = None): + async def end(self, outputs: dict, error: Exception | None = None) -> None: await asyncio.to_thread(self._end_all_traces, outputs, error) self._reset_io() await self.stop() - def add_log(self, trace_name: str, log: Log): + def add_log(self, trace_name: str, log: Log) -> None: self._logs[trace_name].append(log) @asynccontextmanager @@ -228,7 +228,7 @@ async def trace_context( else: self._end_and_reset(trace_id, trace_name) - def _end_and_reset(self, trace_id: str, trace_name: str, error: Exception | None = None): + def _end_and_reset(self, trace_id: str, trace_name: str, error: Exception | None = None) -> None: task = asyncio.create_task(asyncio.to_thread(self._end_traces, trace_id, trace_name, error)) self.end_trace_tasks.add(task) task.add_done_callback(self.end_trace_tasks.discard) @@ -239,7 +239,7 @@ def set_outputs( trace_name: str, outputs: dict[str, Any], output_metadata: dict[str, Any] | None = None, - ): + ) -> None: self.outputs[trace_name] |= outputs or {} self.outputs_metadata[trace_name] |= output_metadata or {} diff --git a/src/backend/base/langflow/services/utils.py b/src/backend/base/langflow/services/utils.py index 30910aae5f0b..b1cb950da213 100644 --- a/src/backend/base/langflow/services/utils.py +++ b/src/backend/base/langflow/services/utils.py @@ -60,7 +60,7 @@ def get_or_create_super_user(session: Session, username, password, is_default): logger.opt(exception=True).debug("Error creating superuser.") -def setup_superuser(settings_service, session: Session): +def setup_superuser(settings_service, session: Session) -> None: if settings_service.auth_settings.AUTO_LOGIN: logger.debug("AUTO_LOGIN is set to True. Creating default superuser.") else: @@ -84,7 +84,7 @@ def setup_superuser(settings_service, session: Session): settings_service.auth_settings.reset_credentials() -def teardown_superuser(settings_service, session): +def teardown_superuser(settings_service, session) -> None: """Teardown the superuser.""" # If AUTO_LOGIN is True, we will remove the default superuser # from the database. @@ -110,7 +110,7 @@ def teardown_superuser(settings_service, session): raise RuntimeError(msg) from exc -async def teardown_services(): +async def teardown_services() -> None: """Teardown all the services.""" try: teardown_superuser(get_settings_service(), next(get_session())) @@ -124,14 +124,14 @@ async def teardown_services(): logger.exception(exc) -def initialize_settings_service(): +def initialize_settings_service() -> None: """Initialize the settings manager.""" from langflow.services.settings import factory as settings_factory get_service(ServiceType.SETTINGS_SERVICE, settings_factory.SettingsServiceFactory()) -def initialize_session_service(): +def initialize_session_service() -> None: """Initialize the session manager.""" from langflow.services.cache import factory as cache_factory from langflow.services.session import factory as session_service_factory @@ -149,7 +149,7 @@ def initialize_session_service(): ) -def initialize_services(*, fix_migration: bool = False): +def initialize_services(*, fix_migration: bool = False) -> None: """Initialize all the services needed.""" # Test cache connection get_service(ServiceType.CACHE_SERVICE, default=CacheServiceFactory()) diff --git a/src/backend/base/langflow/services/variable/factory.py b/src/backend/base/langflow/services/variable/factory.py index 60a327673b73..160bb92f6eb4 100644 --- a/src/backend/base/langflow/services/variable/factory.py +++ b/src/backend/base/langflow/services/variable/factory.py @@ -10,7 +10,7 @@ class VariableServiceFactory(ServiceFactory): - def __init__(self): + def __init__(self) -> None: super().__init__(VariableService) def create(self, settings_service: SettingsService): diff --git a/src/backend/base/langflow/services/variable/kubernetes.py b/src/backend/base/langflow/services/variable/kubernetes.py index 18d0b8ae3f38..b5206f2333fc 100644 --- a/src/backend/base/langflow/services/variable/kubernetes.py +++ b/src/backend/base/langflow/services/variable/kubernetes.py @@ -28,7 +28,7 @@ def __init__(self, settings_service: SettingsService): self.kubernetes_secrets = KubernetesSecretManager() @override - def initialize_user_variables(self, user_id: UUID | str, session: Session): + def initialize_user_variables(self, user_id: UUID | str, session: Session) -> None: # Check for environment variables that should be stored in the database should_or_should_not = "Should" if self.settings_service.settings.store_environment_variables else "Should not" logger.info(f"{should_or_should_not} store environment variables in the kubernetes.") diff --git a/src/backend/base/langflow/services/variable/service.py b/src/backend/base/langflow/services/variable/service.py index 57de1d133b57..11eb086218d5 100644 --- a/src/backend/base/langflow/services/variable/service.py +++ b/src/backend/base/langflow/services/variable/service.py @@ -24,7 +24,7 @@ class DatabaseVariableService(VariableService, Service): def __init__(self, settings_service: SettingsService): self.settings_service = settings_service - def initialize_user_variables(self, user_id: UUID | str, session: Session): + def initialize_user_variables(self, user_id: UUID | str, session: Session) -> None: if not self.settings_service.settings.store_environment_variables: logger.info("Skipping environment variable storage.") return @@ -130,7 +130,7 @@ def delete_variable( user_id: UUID | str, name: str, session: Session, - ): + ) -> None: stmt = select(Variable).where(Variable.user_id == user_id).where(Variable.name == name) variable = session.exec(stmt).first() if not variable: @@ -139,7 +139,7 @@ def delete_variable( session.delete(variable) session.commit() - def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: Session): + def delete_variable_by_id(self, user_id: UUID | str, variable_id: UUID, session: Session) -> None: variable = session.exec(select(Variable).where(Variable.user_id == user_id, Variable.id == variable_id)).first() if not variable: msg = f"{variable_id} variable not found." diff --git a/src/backend/base/langflow/settings.py b/src/backend/base/langflow/settings.py index 564f512220cc..14488e84469a 100644 --- a/src/backend/base/langflow/settings.py +++ b/src/backend/base/langflow/settings.py @@ -1,10 +1,10 @@ DEV = False -def _set_dev(value): +def _set_dev(value) -> None: global DEV # noqa: PLW0603 DEV = value -def set_dev(value): +def set_dev(value) -> None: _set_dev(value) diff --git a/src/backend/base/langflow/template/field/base.py b/src/backend/base/langflow/template/field/base.py index 6bc2036a570b..47c7d02c3862 100644 --- a/src/backend/base/langflow/template/field/base.py +++ b/src/backend/base/langflow/template/field/base.py @@ -192,12 +192,12 @@ class Output(BaseModel): def to_dict(self): return self.model_dump(by_alias=True, exclude_none=True) - def add_types(self, _type: list[Any]): + def add_types(self, _type: list[Any]) -> None: if self.types is None: self.types = [] self.types.extend([t for t in _type if t not in self.types]) - def set_selected(self): + def set_selected(self) -> None: if not self.selected and self.types: self.selected = self.types[0] diff --git a/src/backend/base/langflow/template/frontend_node/base.py b/src/backend/base/langflow/template/frontend_node/base.py index c4e3c29ec05b..bf80442e0864 100644 --- a/src/backend/base/langflow/template/frontend_node/base.py +++ b/src/backend/base/langflow/template/frontend_node/base.py @@ -107,7 +107,7 @@ def add_extra_fields(self) -> None: def add_extra_base_classes(self) -> None: pass - def set_base_classes_from_outputs(self): + def set_base_classes_from_outputs(self) -> None: self.base_classes = [output_type for output in self.outputs for output_type in output.types] def validate_component(self) -> None: @@ -180,13 +180,13 @@ def from_inputs(cls, **kwargs): kwargs["template"] = template return cls(**kwargs) - def set_field_value_in_template(self, field_name, value): + def set_field_value_in_template(self, field_name, value) -> None: for field in self.template.fields: if field.name == field_name: field.value = value break - def set_field_load_from_db_in_template(self, field_name, value): + def set_field_load_from_db_in_template(self, field_name, value) -> None: for field in self.template.fields: if field.name == field_name and hasattr(field, "load_from_db"): field.load_from_db = value diff --git a/src/backend/base/langflow/template/template/base.py b/src/backend/base/langflow/template/template/base.py index 9600e84e0592..f526177e25a1 100644 --- a/src/backend/base/langflow/template/template/base.py +++ b/src/backend/base/langflow/template/template/base.py @@ -16,15 +16,15 @@ class Template(BaseModel): def process_fields( self, format_field_func: Callable | None = None, - ): + ) -> None: if format_field_func: for field in self.fields: format_field_func(field, self.type_name) - def sort_fields(self): + def sort_fields(self) -> None: # first sort alphabetically # then sort fields so that fields that have .field_type in DIRECT_TYPES are first - self.fields.sort(key=lambda x: x.name) + self.fields.sort(key=lambda x: x.name or "") self.fields.sort( key=lambda x: x.field_type in DIRECT_TYPES if hasattr(x, "field_type") else False, reverse=False ) diff --git a/src/backend/base/langflow/template/utils.py b/src/backend/base/langflow/template/utils.py index 2b8e15834755..8f92df5ab59a 100644 --- a/src/backend/base/langflow/template/utils.py +++ b/src/backend/base/langflow/template/utils.py @@ -27,7 +27,7 @@ def get_file_path_value(file_path): return file_path -def update_template_field(new_template, key, previous_value_dict): +def update_template_field(new_template, key, previous_value_dict) -> None: """Updates a specific field in the frontend template.""" template_field = new_template.get(key) if not template_field or template_field.get("type") != previous_value_dict.get("type"): @@ -54,7 +54,7 @@ def is_valid_data(frontend_node, raw_frontend_data): return frontend_node and "template" in frontend_node and raw_frontend_data_is_valid(raw_frontend_data) -def update_template_values(new_template, previous_template): +def update_template_values(new_template, previous_template) -> None: """Updates the frontend template with values from the raw template.""" for key, previous_value_dict in previous_template.items(): if key == "code" or not isinstance(previous_value_dict, dict): diff --git a/src/backend/base/langflow/utils/async_helpers.py b/src/backend/base/langflow/utils/async_helpers.py index 4180555a4f82..04cfbacaef26 100644 --- a/src/backend/base/langflow/utils/async_helpers.py +++ b/src/backend/base/langflow/utils/async_helpers.py @@ -18,7 +18,7 @@ def run_in_thread(coro): result = None exception = None - def target(): + def target() -> None: nonlocal result, exception try: result = asyncio.run(coro) diff --git a/src/backend/base/langflow/utils/concurrency.py b/src/backend/base/langflow/utils/concurrency.py index 33fa1eb42172..4d47aa916cf6 100644 --- a/src/backend/base/langflow/utils/concurrency.py +++ b/src/backend/base/langflow/utils/concurrency.py @@ -10,8 +10,8 @@ class KeyedMemoryLockManager: """A manager for acquiring and releasing memory locks based on a key.""" - def __init__(self): - self.locks = {} + def __init__(self) -> None: + self.locks: dict[str, threading.Lock] = {} self.global_lock = threading.Lock() def _get_lock(self, key: str): @@ -33,7 +33,7 @@ def lock(self, key: str): class KeyedWorkerLockManager: """A manager for acquiring locks between workers based on a key.""" - def __init__(self): + def __init__(self) -> None: self.locks_dir = Path(user_cache_dir("langflow"), ensure_exists=True) / "worker_locks" def _validate_key(self, key: str) -> bool: diff --git a/src/backend/base/langflow/utils/connection_string_parser.py b/src/backend/base/langflow/utils/connection_string_parser.py index cefb1f2818ee..a334cef28e48 100644 --- a/src/backend/base/langflow/utils/connection_string_parser.py +++ b/src/backend/base/langflow/utils/connection_string_parser.py @@ -1,7 +1,7 @@ from urllib.parse import quote -def transform_connection_string(connection_string): +def transform_connection_string(connection_string) -> str: auth_part, db_url_name = connection_string.rsplit("@", 1) protocol_user, password_string = auth_part.rsplit(":", 1) encoded_password = quote(password_string) diff --git a/src/backend/base/langflow/utils/lazy_load.py b/src/backend/base/langflow/utils/lazy_load.py index df0130acc5f5..ebacd3480c87 100644 --- a/src/backend/base/langflow/utils/lazy_load.py +++ b/src/backend/base/langflow/utils/lazy_load.py @@ -1,5 +1,5 @@ class LazyLoadDictBase: - def __init__(self): + def __init__(self) -> None: self._all_types_dict = None @property diff --git a/src/backend/base/langflow/utils/schemas.py b/src/backend/base/langflow/utils/schemas.py index 7b9c21e1cd93..76b494af9baa 100644 --- a/src/backend/base/langflow/utils/schemas.py +++ b/src/backend/base/langflow/utils/schemas.py @@ -108,7 +108,7 @@ class DataOutputResponse(BaseModel): class ContainsEnumMeta(enum.EnumMeta): - def __contains__(cls, item): + def __contains__(cls, item) -> bool: try: cls(item) except ValueError: diff --git a/src/backend/base/langflow/utils/util.py b/src/backend/base/langflow/utils/util.py index 946b3865713f..219025013050 100644 --- a/src/backend/base/langflow/utils/util.py +++ b/src/backend/base/langflow/utils/util.py @@ -414,7 +414,7 @@ def update_settings( auto_saving_interval: int = 1000, health_check_max_retries: int = 5, max_file_size_upload: int = 100, -): +) -> None: """Update the settings from a config file.""" from langflow.services.utils import initialize_settings_service diff --git a/src/backend/base/langflow/utils/validate.py b/src/backend/base/langflow/utils/validate.py index b1be733f71b4..5178cf735720 100644 --- a/src/backend/base/langflow/utils/validate.py +++ b/src/backend/base/langflow/utils/validate.py @@ -10,13 +10,13 @@ from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES -def add_type_ignores(): +def add_type_ignores() -> None: if not hasattr(ast, "TypeIgnore"): class TypeIgnore(ast.AST): _fields = () - ast.TypeIgnore = TypeIgnore + ast.TypeIgnore = TypeIgnore # type: ignore[assignment, misc] def validate_code(code): diff --git a/src/backend/base/pyproject.toml b/src/backend/base/pyproject.toml index b2dc802ec6fd..2bfdd4fdaa83 100644 --- a/src/backend/base/pyproject.toml +++ b/src/backend/base/pyproject.toml @@ -28,6 +28,7 @@ exclude = ["langflow/alembic"] line-length = 120 [tool.ruff.lint] +flake8-annotations.mypy-init-return = true flake8-bugbear.extend-immutable-calls = [ "fastapi.Depends", "fastapi.File",