Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ref: Auto-fix some ruff ANN rules #4210

Merged
merged 4 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/backend/base/langflow/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_number_of_workers(workers=None):
return workers


def display_results(results):
def display_results(results) -> None:
"""Display the results of the migration."""
for table_results in results:
table = Table(title=f"Migration {table_results.table_name}")
Expand All @@ -62,7 +62,7 @@ def display_results(results):
console.print() # Print a new line


def set_var_for_macos_issue():
def set_var_for_macos_issue() -> None:
# OBJC_DISABLE_INITIALIZE_FORK_SAFETY=YES
# we need to set this var is we are running on MacOS
# otherwise we get an error when running gunicorn
Expand Down Expand Up @@ -146,7 +146,7 @@ def run(
help="Defines the maximum file size for the upload in MB.",
show_default=False,
),
):
) -> None:
"""Run Langflow."""
configure(log_level=log_level, log_file=log_file)
set_var_for_macos_issue()
Expand Down Expand Up @@ -202,7 +202,7 @@ def run(
# Run using uvicorn on MacOS and Windows
# Windows doesn't support gunicorn
# MacOS requires an env variable to be set to use gunicorn
process = run_on_windows(host, port, log_level, options, app)
run_on_windows(host, port, log_level, options, app)
else:
# Run using gunicorn on Linux
process = run_on_mac_or_linux(host, port, log_level, options, app)
Expand All @@ -219,7 +219,7 @@ def run(
sys.exit(1)


def wait_for_server_ready(host, port):
def wait_for_server_ready(host, port) -> None:
"""Wait for the server to become ready by polling the health endpoint."""
status_code = 0
while status_code != httpx.codes.OK:
Expand All @@ -241,7 +241,7 @@ def run_on_mac_or_linux(host, port, log_level, options, app):
return webapp_process


def run_on_windows(host, port, log_level, options, app):
def run_on_windows(host, port, log_level, options, app) -> None:
"""Run the Langflow server on Windows."""
print_banner(host, port)
run_langflow(host, port, log_level, options, app)
Expand Down Expand Up @@ -275,7 +275,7 @@ def get_free_port(port):
return port


def get_letter_from_version(version: str):
def get_letter_from_version(version: str) -> str | None:
"""Get the letter from a pre-release version."""
if "a" in version:
return "a"
Expand All @@ -294,7 +294,7 @@ def build_version_notice(current_version: str, package_name: str) -> str:
return ""


def generate_pip_command(package_names, is_pre_release):
def generate_pip_command(package_names, is_pre_release) -> str:
"""Generate the pip install command based on the packages and whether it's a pre-release."""
base_command = "pip install"
if is_pre_release:
Expand All @@ -309,7 +309,7 @@ def stylize_text(text: str, to_style: str, *, is_prerelease: bool) -> str:
return text.replace(to_style, styled_text)


def print_banner(host: str, port: int):
def print_banner(host: str, port: int) -> None:
notices = []
package_names = [] # Track package names for pip install instructions
is_pre_release = False # Track if any package is a pre-release
Expand Down Expand Up @@ -355,7 +355,7 @@ def print_banner(host: str, port: int):
rprint(panel)


def run_langflow(host, port, log_level, options, app):
def run_langflow(host, port, log_level, options, app) -> None:
"""Run Langflow server on localhost."""
if platform.system() == "Windows":
# Run using uvicorn on MacOS and Windows
Expand All @@ -381,7 +381,7 @@ def superuser(
username: str = typer.Option(..., prompt=True, help="Username for the superuser."),
password: str = typer.Option(..., prompt=True, hide_input=True, help="Password for the superuser."),
log_level: str = typer.Option("error", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"),
):
) -> None:
"""Create a superuser."""
configure(log_level=log_level)
initialize_services()
Expand Down Expand Up @@ -413,7 +413,7 @@ def superuser(
# command to copy the langflow database from the cache to the current directory
# because now the database is stored per installation
@app.command()
def copy_db():
def copy_db() -> None:
"""Copy the database files to the current directory.

This function copies the 'langflow.db' and 'langflow-pre.db' files from the cache directory to the current
Expand Down Expand Up @@ -452,7 +452,7 @@ def migration(
default=False,
help="Fix migrations. This is a destructive operation, and should only be used if you know what you are doing.",
),
):
) -> None:
"""Run or test migrations."""
if fix and not typer.confirm(
"This will delete all data necessary to fix migrations. Are you sure you want to continue?"
Expand All @@ -470,7 +470,7 @@ def migration(
@app.command()
def api_key(
log_level: str = typer.Option("error", help="Logging level."),
):
) -> None:
"""Creates an API key for the default superuser if AUTO_LOGIN is enabled.

Args:
Expand Down Expand Up @@ -510,7 +510,7 @@ def api_key(
api_key_banner(unmasked_api_key)


def api_key_banner(unmasked_api_key):
def api_key_banner(unmasked_api_key) -> None:
is_mac = platform.system() == "Darwin"
import pyperclip

Expand All @@ -529,7 +529,7 @@ def api_key_banner(unmasked_api_key):
console.print(panel)


def main():
def main() -> None:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
app()
Expand Down
4 changes: 2 additions & 2 deletions src/backend/base/langflow/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def get_is_component_from_data(data: dict):
return data.get("is_component")


async def check_langflow_version(component: StoreComponentCreate):
async def check_langflow_version(component: StoreComponentCreate) -> None:
from langflow.utils.version import get_version_info

__version__ = get_version_info()["version"]
Expand Down Expand Up @@ -259,7 +259,7 @@ def parse_value(value: Any, input_type: str) -> Any:
return value


async def cascade_delete_flow(session: Session, flow: Flow):
async def cascade_delete_flow(session: Session, flow: Flow) -> None:
try:
session.exec(delete(TransactionTable).where(TransactionTable.flow_id == flow.id))
session.exec(delete(VertexBuildTable).where(VertexBuildTable.flow_id == flow.id))
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/api/v1/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ async def on_text( # type: ignore[misc]
@override
async def on_agent_action( # type: ignore[misc]
self, action: AgentAction, **kwargs: Any
):
) -> None:
log = f"Thought: {action.log}"
# if there are line breaks, split them and send them
# as separate messages
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ async def consume_and_yield(queue: asyncio.Queue, client_consumed_queue: asyncio
event_manager = create_default_event_manager(queue=asyncio_queue)
main_task = asyncio.create_task(event_generator(event_manager, asyncio_queue_client_consumed))

def on_disconnect():
def on_disconnect() -> None:
logger.debug("Client disconnected, closing tasks")
main_task.cancel()

Expand Down
7 changes: 2 additions & 5 deletions src/backend/base/langflow/api/v1/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
CustomComponentRequest,
CustomComponentResponse,
InputValueRequest,
ProcessResponse,
RunResponse,
SidebarCategoriesResponse,
SimplifiedAPIRequest,
Expand Down Expand Up @@ -68,7 +67,7 @@ async def get_all():
raise HTTPException(status_code=500, detail=str(exc)) from exc


def validate_input_and_tweaks(input_request: SimplifiedAPIRequest):
def validate_input_and_tweaks(input_request: SimplifiedAPIRequest) -> None:
# If the input_value is not None and the input_type is "chat"
# then we need to check the tweaks if the ChatInput component is present
# and if its input_value is not None
Expand Down Expand Up @@ -483,15 +482,13 @@ async def experimental_run_flow(

@router.post(
"/predict/{flow_id}",
response_model=ProcessResponse,
dependencies=[Depends(api_key_security)],
)
@router.post(
"/process/{flow_id}",
response_model=ProcessResponse,
dependencies=[Depends(api_key_security)],
)
async def process():
async def process() -> None:
"""Endpoint to process an input with a given flow_id."""
# Raise a depreciation warning
logger.warning(
Expand Down
4 changes: 2 additions & 2 deletions src/backend/base/langflow/api/v1/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def get_vertex_builds(
async def delete_vertex_builds(
flow_id: Annotated[UUID, Query()],
session: Annotated[Session, Depends(get_session)],
):
) -> None:
try:
delete_vertex_builds_by_flow_id(session, flow_id)
except Exception as e:
Expand Down Expand Up @@ -75,7 +75,7 @@ async def get_messages(
async def delete_messages(
message_ids: list[UUID],
session: Annotated[Session, Depends(get_session)],
):
) -> None:
try:
session.exec(delete(MessageTable).where(MessageTable.id.in_(message_ids))) # type: ignore[attr-defined]
session.commit()
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/api/v1/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def delete_variable(
variable_id: UUID,
current_user: User = Depends(get_current_active_user),
variable_service: VariableService = Depends(get_variable_service),
):
) -> None:
"""Delete a variable."""
try:
variable_service.delete_variable_by_id(user_id=current_user.id, variable_id=variable_id, session=session)
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def message_response(self) -> Message:
self.status = message
return message

def _validate_outputs(self):
def _validate_outputs(self) -> None:
required_output_methods = ["build_agent"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:
Expand Down
4 changes: 2 additions & 2 deletions src/backend/base/langflow/base/agents/crewai/crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def build_crew(self) -> Crew:
def get_task_callback(
self,
) -> Callable:
def task_callback(task_output: TaskOutput):
def task_callback(task_output: TaskOutput) -> None:
vertex_id = self._vertex.id if self._vertex else self.display_name or self.__class__.__name__
self.log(task_output.model_dump(), name=f"Task (Agent: {task_output.agent}) - {vertex_id}")

Expand All @@ -61,7 +61,7 @@ def task_callback(task_output: TaskOutput):
def get_step_callback(
self,
) -> Callable:
def step_callback(agent_output: AgentFinish | list[tuple[AgentAction, str]]):
def step_callback(agent_output: AgentFinish | list[tuple[AgentAction, str]]) -> None:
_id = self._vertex.id if self._vertex else self.display_name
if isinstance(agent_output, AgentFinish):
messages = agent_output.messages
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/astra_assistants/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_patched_openai_client(shared_component_cache):
tools_and_names = {}


def tools_from_package(your_package):
def tools_from_package(your_package) -> None:
# Iterate over all modules in the package
package_name = your_package.__name__
for module_info in pkgutil.iter_modules(your_package.__path__):
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/chains/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class LCChainComponent(Component):

outputs = [Output(display_name="Text", name="text", method="invoke_chain")]

def _validate_outputs(self):
def _validate_outputs(self) -> None:
required_output_methods = ["invoke_chain"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/embeddings/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class LCEmbeddingsModel(Component):
Output(display_name="Embeddings", name="embeddings", method="build_embeddings"),
]

def _validate_outputs(self):
def _validate_outputs(self) -> None:
required_output_methods = ["build_embeddings"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:
Expand Down
4 changes: 2 additions & 2 deletions src/backend/base/langflow/base/io/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def store_message(self, message: Message) -> Message:
self.status = stored_message
return stored_message

def _send_message_event(self, message: Message):
def _send_message_event(self, message: Message) -> None:
if hasattr(self, "_event_manager") and self._event_manager:
self._event_manager.on_message(data=message.data)

Expand Down Expand Up @@ -107,7 +107,7 @@ def _create_message(self, input_value, sender, sender_name, files, session_id) -
return Message.from_data(input_value)
return Message(text=input_value, sender=sender, sender_name=sender_name, files=files, session_id=session_id)

def _send_messages_events(self, messages):
def _send_messages_events(self, messages) -> None:
if hasattr(self, "_event_manager") and self._event_manager:
for stored_message in messages:
self._event_manager.on_message(data=stored_message.data)
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class LCToolComponent(Component):
Output(name="api_build_tool", display_name="Tool", method="build_tool"),
]

def _validate_outputs(self):
def _validate_outputs(self) -> None:
required_output_methods = ["run_model", "build_tool"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ def get_messages(self, **kwargs) -> list[Data]:

def add_message(
self, sender: str, sender_name: str, text: str, session_id: str, metadata: dict | None = None, **kwargs
):
) -> None:
raise NotImplementedError
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/memory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class LCChatMemoryComponent(Component):
)
]

def _validate_outputs(self):
def _validate_outputs(self) -> None:
required_output_methods = ["build_message_history"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class LCModelComponent(Component):
def _get_exception_message(self, e: Exception):
return str(e)

def _validate_outputs(self):
def _validate_outputs(self) -> None:
# At least these two outputs must be defined
required_output_methods = ["text_response", "build_model"]
output_names = [output.name for output in self.outputs]
Expand Down
8 changes: 4 additions & 4 deletions src/backend/base/langflow/base/prompts/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _check_variable(var, invalid_chars, wrong_variables, empty_variables):
return wrong_variables, empty_variables


def _check_for_errors(input_variables, fixed_variables, wrong_variables, empty_variables):
def _check_for_errors(input_variables, fixed_variables, wrong_variables, empty_variables) -> None:
if any(var for var in input_variables if var not in fixed_variables):
error_message = (
f"Error: Input variables contain invalid characters or formats. \n"
Expand Down Expand Up @@ -159,7 +159,7 @@ def get_old_custom_fields(custom_fields, name):
return old_custom_fields


def add_new_variables_to_template(input_variables, custom_fields, template, name):
def add_new_variables_to_template(input_variables, custom_fields, template, name) -> None:
for variable in input_variables:
try:
template_field = DefaultPromptField(name=variable, display_name=variable)
Expand All @@ -177,7 +177,7 @@ def add_new_variables_to_template(input_variables, custom_fields, template, name
raise HTTPException(status_code=500, detail=str(exc)) from exc


def remove_old_variables_from_template(old_custom_fields, input_variables, custom_fields, template, name):
def remove_old_variables_from_template(old_custom_fields, input_variables, custom_fields, template, name) -> None:
for variable in old_custom_fields:
if variable not in input_variables:
try:
Expand All @@ -192,7 +192,7 @@ def remove_old_variables_from_template(old_custom_fields, input_variables, custo
raise HTTPException(status_code=500, detail=str(exc)) from exc


def update_input_variables_field(input_variables, template):
def update_input_variables_field(input_variables, template) -> None:
if "input_variables" in template:
template["input_variables"]["value"] = input_variables

Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/textsplitters/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class LCTextSplitterComponent(LCDocumentTransformerComponent):
trace_type = "text_splitter"

def _validate_outputs(self):
def _validate_outputs(self) -> None:
required_output_methods = ["text_splitter"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:
Expand Down
Loading
Loading