From c9762174b33a88952267538e523bb3d91fc25e70 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 14 Oct 2025 19:49:47 +0300 Subject: [PATCH 1/8] feat(api-nodes): implement new API client for V3 nodes --- comfy_api_nodes/nodes_bytedance.py | 221 ++---- comfy_api_nodes/util/__init__.py | 15 + comfy_api_nodes/util/_helpers.py | 58 ++ comfy_api_nodes/util/api_client.py | 915 ++++++++++++++++++++++ comfy_api_nodes/util/common_exceptions.py | 14 + comfy_api_nodes/util/storage_helpers.py | 272 +++++++ pyproject.toml | 2 + 7 files changed, 1335 insertions(+), 162 deletions(-) create mode 100644 comfy_api_nodes/util/_helpers.py create mode 100644 comfy_api_nodes/util/api_client.py create mode 100644 comfy_api_nodes/util/common_exceptions.py create mode 100644 comfy_api_nodes/util/storage_helpers.py diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index f3d3f8d3eeab..d97e0a3a8aba 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1,7 +1,7 @@ import logging import math from enum import Enum -from typing import Literal, Optional, Type, Union +from typing import Literal, Optional, Union from typing_extensions import override import torch @@ -13,18 +13,15 @@ get_number_of_images, validate_image_dimensions, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - EmptyRequest, - HttpMethod, - SynchronousOperation, - PollingOperation, - T, + sync_op_pydantic, + poll_op_pydantic, + upload_images_to_comfyapi, ) from comfy_api_nodes.apinode_utils import ( download_url_to_image_tensor, download_url_to_video_output, - upload_images_to_comfyapi, validate_string, image_tensor_pair_to_batch, ) @@ -208,35 +205,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N return None -async def poll_until_finished( - auth_kwargs: dict[str, str], - task_id: str, - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> TaskStatusResponse: - """Polls the ByteDance API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - completed_statuses=[ - "succeeded", - ], - failed_statuses=[ - "cancelled", - "failed", - ], - status_extractor=lambda response: response.status, - auth_kwargs=auth_kwargs, - result_url_extractor=get_video_url_from_task_status, - estimated_duration=estimated_duration, - node_id=node_id, - ).execute() - - class ByteDanceImageNode(IO.ComfyNode): @classmethod @@ -353,20 +321,12 @@ async def execute( guidance_scale=guidance_scale, watermark=watermark, ) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Text2ImageTaskCreationRequest, - response_model=ImageTaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() + response = await sync_op_pydantic( + cls, + endpoint=ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + data=payload, + response_model=ImageTaskCreationResponse, + ) return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) @@ -449,16 +409,7 @@ async def execute( if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") validate_image_aspect_ratio_range(image, (1, 3), (3, 1)) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - source_url = (await upload_images_to_comfyapi( - image, - max_images=1, - mime_type="image/png", - auth_kwargs=auth_kwargs, - ))[0] + source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png",))[0] payload = Image2ImageTaskCreationRequest( model=model, prompt=prompt, @@ -467,16 +418,12 @@ async def execute( guidance_scale=guidance_scale, watermark=watermark, ) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Image2ImageTaskCreationRequest, - response_model=ImageTaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() + response = await sync_op_pydantic( + cls, + endpoint=ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + data=payload, + response_model=ImageTaskCreationResponse, + ) return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) @@ -621,41 +568,31 @@ async def execute( raise ValueError( "The maximum number of generated images plus the number of reference images cannot exceed 15." ) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } reference_images_urls = [] if n_input_images: for i in image: validate_image_aspect_ratio_range(i, (1, 3), (3, 1)) - reference_images_urls = (await upload_images_to_comfyapi( + reference_images_urls = await upload_images_to_comfyapi( + cls, image, max_images=n_input_images, mime_type="image/png", - auth_kwargs=auth_kwargs, - )) - payload = Seedream4TaskCreationRequest( - model=model, - prompt=prompt, - image=reference_images_urls, - size=f"{w}x{h}", - seed=seed, - sequential_image_generation=sequential_image_generation, - sequential_image_generation_options=Seedream4Options(max_images=max_images), - watermark=watermark, - ) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Seedream4TaskCreationRequest, - response_model=ImageTaskCreationResponse, + ) + response = await sync_op_pydantic( + cls, + endpoint=ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + response_model=ImageTaskCreationResponse, + data=Seedream4TaskCreationRequest( + model=model, + prompt=prompt, + image=reference_images_urls, + size=f"{w}x{h}", + seed=seed, + sequential_image_generation=sequential_image_generation, + sequential_image_generation_options=Seedream4Options(max_images=max_images), + watermark=watermark, ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - + ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d] @@ -764,19 +701,9 @@ async def execute( f"--camerafixed {str(camera_fixed).lower()} " f"--watermark {str(watermark).lower()}" ) - - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } return await process_video_task( - request_model=Text2VideoTaskCreationRequest, - payload=Text2VideoTaskCreationRequest( - model=model, - content=[TaskTextContent(text=prompt)], - ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, + cls, + payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -879,13 +806,7 @@ async def execute( validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth_kwargs))[0] - + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] prompt = ( f"{prompt} " f"--resolution {resolution} " @@ -897,13 +818,11 @@ async def execute( ) return await process_video_task( - request_model=Image2VideoTaskCreationRequest, + cls, payload=Image2VideoTaskCreationRequest( model=model, content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))], ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -1012,16 +931,11 @@ async def execute( validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - download_urls = await upload_images_to_comfyapi( + cls, image_tensor_pair_to_batch(first_frame, last_frame), max_images=2, mime_type="image/png", - auth_kwargs=auth_kwargs, ) prompt = ( @@ -1035,7 +949,7 @@ async def execute( ) return await process_video_task( - request_model=Image2VideoTaskCreationRequest, + cls, payload=Image2VideoTaskCreationRequest( model=model, content=[ @@ -1044,8 +958,6 @@ async def execute( TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"), ], ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -1141,15 +1053,7 @@ async def execute( validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - image_urls = await upload_images_to_comfyapi( - images, max_images=4, mime_type="image/png", auth_kwargs=auth_kwargs - ) - + image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png") prompt = ( f"{prompt} " f"--resolution {resolution} " @@ -1163,39 +1067,32 @@ async def execute( *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls] ] return await process_video_task( - request_model=Image2VideoTaskCreationRequest, - payload=Image2VideoTaskCreationRequest( - model=model, - content=x, - ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, + cls, + payload=Image2VideoTaskCreationRequest(model=model, content=x), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) async def process_video_task( - request_model: Type[T], + cls: type[IO.ComfyNode], payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], - auth_kwargs: dict, - node_id: str, estimated_duration: Optional[int], ) -> IO.NodeOutput: - initial_response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_TASK_ENDPOINT, - method=HttpMethod.POST, - request_model=request_model, - response_model=TaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - response = await poll_until_finished( - auth_kwargs, - initial_response.id, + initial_response = await sync_op_pydantic( + cls, + endpoint=ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), + data=payload, + response_model=TaskCreationResponse, + ) + response = await poll_op_pydantic( + cls, + poll_endpoint=ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"), + completed_statuses=["succeeded"], + failed_statuses=["cancelled", "failed"], + queued_states=["queued"], + status_extractor=lambda r: r.status, estimated_duration=estimated_duration, - node_id=node_id, + response_model=TaskStatusResponse, ) return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response))) diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index e69de29bb2d1..cad902fc0d9b 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -0,0 +1,15 @@ +from .api_client import ApiEndpoint, sync_op_pydantic, poll_op_pydantic, sync_op, poll_op +from .storage_helpers import ( + upload_file_to_comfyapi, + upload_images_to_comfyapi, +) + +__all__ = [ + "ApiEndpoint", + "poll_op", + "sync_op", + "poll_op_pydantic", + "sync_op_pydantic", + "upload_file_to_comfyapi", + "upload_images_to_comfyapi", +] diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py new file mode 100644 index 000000000000..1bf951adc48e --- /dev/null +++ b/comfy_api_nodes/util/_helpers.py @@ -0,0 +1,58 @@ +import asyncio +import contextlib +import time +from typing import Optional, Callable + +from comfy_api.latest import IO +from comfy.cli_args import args +from comfy.model_management import processing_interrupted + +from .common_exceptions import ProcessingInterrupted + + +def _is_processing_interrupted() -> bool: + """Return True if user/runtime requested interruption.""" + return processing_interrupted() + + +def _get_node_id(node_cls: type[IO.ComfyNode]) -> str: + return node_cls.hidden.unique_id + + +def _get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]: + if node_cls.hidden.auth_token_comfy_org: + return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"} + if node_cls.hidden.api_key_comfy_org: + return {"X-API-KEY": node_cls.hidden.api_key_comfy_org} + return {} + + +def _default_base_url() -> str: + return getattr(args, "comfy_api_base", "https://api.comfy.org") + + +async def _sleep_with_interrupt( + seconds: float, + node_cls: type[IO.ComfyNode], + label: Optional[str] = None, + start_ts: Optional[float] = None, + estimated_total: Optional[int] = None, + *, + display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None, +): + """ + Sleep in 1s slices while: + - Checking for interruption (raises ProcessingInterrupted). + - Optionally emitting time progress via display_callback (if provided). + """ + end = time.monotonic() + seconds + while True: + if _is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + now = time.monotonic() + if start_ts is not None and label and display_callback: + with contextlib.suppress(Exception): + display_callback(node_cls, label, int(now - start_ts), estimated_total) + if now >= end: + break + await asyncio.sleep(min(1.0, end - now)) diff --git a/comfy_api_nodes/util/api_client.py b/comfy_api_nodes/util/api_client.py new file mode 100644 index 000000000000..95614d6b6987 --- /dev/null +++ b/comfy_api_nodes/util/api_client.py @@ -0,0 +1,915 @@ +import asyncio +import contextlib +import json +import logging +import socket +import time +import uuid +from dataclasses import dataclass +from enum import Enum +from io import BytesIO +from typing import Any, Callable, Optional, Union, Type, TypeVar, Literal + +import aiohttp +from aiohttp.client_exceptions import ClientError, ContentTypeError +from comfy_api.latest import IO +from comfy import utils +from pydantic import BaseModel +from server import PromptServer +from urllib.parse import urljoin, urlparse + +from comfy_api_nodes.apis import request_logger +from .common_exceptions import ProcessingInterrupted, LocalNetworkError, ApiServerError +from ._helpers import ( + _is_processing_interrupted, + _get_node_id, + _get_auth_header, + _default_base_url, + _sleep_with_interrupt, +) + + +M = TypeVar("M", bound=BaseModel) + + +class ApiEndpoint: + def __init__( + self, + path: str, + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET", + *, + query_params: Optional[dict[str, Any]] = None, + headers: Optional[dict[str, str]] = None, + ): + self.path = path + self.method = method + self.query_params = query_params or {} + self.headers = headers or {} + + +@dataclass +class _RequestConfig: + node_cls: type[IO.ComfyNode] + endpoint: ApiEndpoint + timeout: float + content_type: str + data: Optional[dict[str, Any]] + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] + multipart_parser: Optional[Callable] + max_retries: int + retry_delay: float + retry_backoff: float + wait_label: str = "Waiting" + monitor_progress: bool = True + estimated_total: Optional[int] = None + final_label_on_success: Optional[str] = "Completed" + progress_origin_ts: Optional[float] = None + + +@dataclass +class _PollUIState: + started: float + status_label: str = "Queued" + is_queued: bool = True + price: Optional[float] = None + estimated_duration: Optional[int] = None + base_processing_elapsed: float = 0.0 # sum of completed active intervals + active_since: Optional[float] = None # start time of current active interval (None if queued) + + +_RETRY_STATUS = {408, 429, 500, 502, 503, 504} + + +async def sync_op_pydantic( + cls: type[IO.ComfyNode], + endpoint: ApiEndpoint, + *, + response_model: Type[M], + data: Optional[BaseModel] = None, + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + content_type: str = "application/json", + timeout: float = 3600.0, + multipart_parser: Optional[Callable] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: str = "Waiting for server", + estimated_total: Optional[int] = None, + final_label_on_success: Optional[str] = "Completed", + progress_origin_ts: Optional[float] = None, + monitor_progress: bool = True, +) -> M: + raw = await sync_op( + cls, + endpoint, + data=data, + files=files, + content_type=content_type, + timeout=timeout, + multipart_parser=multipart_parser, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff=retry_backoff, + wait_label=wait_label, + estimated_total=estimated_total, + as_binary=False, + final_label_on_success=final_label_on_success, + progress_origin_ts=progress_origin_ts, + monitor_progress=monitor_progress, + ) + if not isinstance(raw, dict): + raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + return _validate_or_raise(response_model, raw) + + +async def poll_op_pydantic( + cls: type[IO.ComfyNode], + *, + poll_endpoint: ApiEndpoint, + response_model: Type[M], + status_extractor: Callable[[M], Optional[str]], + progress_extractor: Optional[Callable[[M], Optional[int]]] = None, + price_extractor: Optional[Callable[[M], Optional[float]]] = None, + completed_statuses: list[str], + failed_statuses: list[str], + queued_states: Optional[list[str]] = None, + poll_interval: float = 5.0, + max_poll_attempts: int = 120, + timeout_per_poll: float = 120.0, + max_retries_per_poll: int = 3, + retry_delay_per_poll: float = 1.0, + retry_backoff_per_poll: float = 2.0, + estimated_duration: Optional[int] = None, + cancel_endpoint: Optional[ApiEndpoint] = None, + cancel_timeout: float = 10.0, +) -> M: + raw = await poll_op( + cls, + poll_endpoint=poll_endpoint, + status_extractor=_wrap_model_extractor(response_model, status_extractor), + progress_extractor=_wrap_model_extractor(response_model, progress_extractor), + price_extractor=_wrap_model_extractor(response_model, price_extractor), + completed_statuses=completed_statuses, + failed_statuses=failed_statuses, + queued_states=queued_states, + poll_interval=poll_interval, + max_poll_attempts=max_poll_attempts, + timeout_per_poll=timeout_per_poll, + max_retries_per_poll=max_retries_per_poll, + retry_delay_per_poll=retry_delay_per_poll, + retry_backoff_per_poll=retry_backoff_per_poll, + estimated_duration=estimated_duration, + cancel_endpoint=cancel_endpoint, + cancel_timeout=cancel_timeout, + ) + if not isinstance(raw, dict): + raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + return _validate_or_raise(response_model, raw) + + +async def sync_op( + cls: type[IO.ComfyNode], + endpoint: ApiEndpoint, + *, + data: Optional[Union[dict[str, Any], BaseModel]] = None, + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + content_type: str = "application/json", + timeout: float = 3600.0, + multipart_parser: Optional[Callable] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: str = "Waiting for server", + estimated_total: Optional[int] = None, + as_binary: bool = False, + final_label_on_success: Optional[str] = "Completed", + progress_origin_ts: Optional[float] = None, + monitor_progress: bool = True, +) -> Union[dict[str, Any], bytes]: + """ + Make a single network request. + - If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). + - If as_binary=True: returns bytes. + """ + if isinstance(data, BaseModel): + data = data.model_dump(exclude_none=True) + for k, v in list(data.items()): + if isinstance(v, Enum): + data[k] = v.value + cfg = _RequestConfig( + node_cls=cls, + endpoint=endpoint, + timeout=timeout, + content_type=content_type, + data=data, + files=files, + multipart_parser=multipart_parser, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff=retry_backoff, + wait_label=wait_label, + monitor_progress=monitor_progress, + estimated_total=estimated_total, + final_label_on_success=final_label_on_success, + progress_origin_ts=progress_origin_ts, + ) + return await _request_base(cfg, expect_binary=as_binary) + + +async def poll_op( + cls: type[IO.ComfyNode], + *, + poll_endpoint: ApiEndpoint, + status_extractor: Callable[[dict[str, Any]], Optional[str]], + progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None, + price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, + completed_statuses: list[str], + failed_statuses: list[str], + queued_states: Optional[list[str]] = None, + poll_interval: float = 5.0, + max_poll_attempts: int = 120, + timeout_per_poll: float = 120.0, + max_retries_per_poll: int = 3, + retry_delay_per_poll: float = 1.0, + retry_backoff_per_poll: float = 2.0, + estimated_duration: Optional[int] = None, + cancel_endpoint: Optional[ApiEndpoint] = None, + cancel_timeout: float = 10.0, +) -> dict[str, Any]: + """ + Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing, + checks interruption every second, and calls Cancel endpoint (if provided) on interruption. + Returns the final JSON response from the poll endpoint. + """ + queued_states = queued_states or [] + started = time.monotonic() + consumed_attempts = 0 # counts only non-queued polls + + progress_bar = utils.ProgressBar(100) if progress_extractor else None + last_progress: Optional[int] = None + + state = _PollUIState(started=started, estimated_duration=estimated_duration) + stop_ticker = asyncio.Event() + + async def _ticker(): + """Emit a UI update every second while polling is in progress.""" + try: + while not stop_ticker.is_set(): + if _is_processing_interrupted(): + break + now = time.monotonic() + proc_elapsed = state.base_processing_elapsed + ( + (now - state.active_since) if state.active_since is not None else 0.0 + ) + _display_time_progress( + cls, + label=state.status_label, + elapsed_seconds=int(now - state.started), + estimated_total=state.estimated_duration, + price=state.price, + is_queued=state.is_queued, + processing_elapsed_seconds=int(proc_elapsed), + ) + await asyncio.sleep(1.0) + except Exception as exc: + logging.debug("Polling ticker exited: %s", exc) + + ticker_task = asyncio.create_task(_ticker()) + try: + while consumed_attempts < max_poll_attempts: + try: + resp_json = await sync_op( + cls, + poll_endpoint, + timeout=timeout_per_poll, + max_retries=max_retries_per_poll, + retry_delay=retry_delay_per_poll, + retry_backoff=retry_backoff_per_poll, + wait_label="Checking", + estimated_total=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + if not isinstance(resp_json, dict): + raise Exception("Polling endpoint returned non-JSON response.") + except ProcessingInterrupted: + if cancel_endpoint: + with contextlib.suppress(Exception): + await sync_op( + cls, + cancel_endpoint, + timeout=cancel_timeout, + max_retries=0, + wait_label="Cancelling task", + estimated_total=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + raise + + try: + status = status_extractor(resp_json) + except Exception as e: + logging.error("Status extraction failed: %s", e) + status = None + + if price_extractor: + new_price = price_extractor(resp_json) + if new_price is not None: + state.price = new_price + + if progress_extractor: + new_progress = progress_extractor(resp_json) + if new_progress is not None and last_progress != new_progress: + progress_bar.update_absolute(new_progress, total=100) + last_progress = new_progress + + now_ts = time.monotonic() + is_queued = status in queued_states + + if is_queued: + if state.active_since is not None: # If we just moved from active -> queued, close the active interval + state.base_processing_elapsed += (now_ts - state.active_since) + state.active_since = None + else: + if state.active_since is None: # If we just moved from queued -> active, open a new active interval + state.active_since = now_ts + + state.is_queued = is_queued + state.status_label = status or ("Queued" if is_queued else "Processing") + if status in completed_statuses: + if state.active_since is not None: + state.base_processing_elapsed += (now_ts - state.active_since) + state.active_since = None + stop_ticker.set() + with contextlib.suppress(Exception): + await ticker_task + + if progress_bar and last_progress != 100: + progress_bar.update_absolute(100, total=100) + + _display_time_progress( + cls, + label=status if status else "Completed", + elapsed_seconds=int(now_ts - started), + estimated_total=estimated_duration, + price=state.price, + is_queued=False, + processing_elapsed_seconds=int(state.base_processing_elapsed), + ) + return resp_json + + if status in failed_statuses: + msg = f"Task failed: {json.dumps(resp_json)}" + logging.error(msg) + raise Exception(msg) + + try: + await _sleep_with_interrupt(poll_interval, cls, None, None, None) + except ProcessingInterrupted: + if cancel_endpoint: + with contextlib.suppress(Exception): + await sync_op( + cls, + cancel_endpoint, + timeout=cancel_timeout, + max_retries=0, + wait_label="Cancelling task", + estimated_total=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + raise + if not is_queued: + consumed_attempts += 1 + + raise Exception( + f"Polling timed out after {max_poll_attempts} non-queued attempts " + f"(~{int(max_poll_attempts * poll_interval)}s of active polling)." + ) + except ProcessingInterrupted: + raise + except (LocalNetworkError, ApiServerError): + raise + except Exception as e: + raise Exception(f"Polling aborted due to error: {e}") from e + finally: + stop_ticker.set() + with contextlib.suppress(Exception): + await ticker_task + + +def _display_text( + node_cls: type[IO.ComfyNode], + text: Optional[str], + *, + status: Optional[str] = None, + price: Optional[float] = None, +) -> None: + display_lines: list[str] = [] + if status: + display_lines.append(f"Status: {status.capitalize()}") + if price is not None: + display_lines.append(f"Price: ${float(price):,.4f}") + if text is not None: + display_lines.append(text) + if display_lines: + PromptServer.instance.send_progress_text("\n".join(display_lines), _get_node_id(node_cls)) + + +def _display_time_progress( + node_cls: type[IO.ComfyNode], + label: str, + elapsed_seconds: int, + estimated_total: Optional[int] = None, + *, + price: Optional[float] = None, + is_queued: Optional[bool] = None, + processing_elapsed_seconds: Optional[int] = None, +) -> None: + if estimated_total is not None and estimated_total > 0 and is_queued is False: + pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds + remaining = max(0, int(estimated_total) - int(pe)) + time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)" + else: + time_line = f"Time elapsed: {int(elapsed_seconds)}s" + _display_text(node_cls, time_line, status=label, price=price) + + +async def _diagnose_connectivity() -> dict[str, bool]: + """Best-effort connectivity diagnostics to distinguish local vs. server issues.""" + results = { + "internet_accessible": False, + "api_accessible": False, + "is_local_issue": False, + "is_api_issue": False, + } + timeout = aiohttp.ClientTimeout(total=5.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + try: + async with session.get("https://www.google.com") as resp: + results["internet_accessible"] = resp.status < 500 + except (ClientError, asyncio.TimeoutError, socket.gaierror): + results["is_local_issue"] = True + return results + + parsed = urlparse(_default_base_url()) + health_url = f"{parsed.scheme}://{parsed.netloc}/health" + with contextlib.suppress(ClientError, asyncio.TimeoutError): + async with session.get(health_url) as resp: + results["api_accessible"] = resp.status < 500 + results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"] + return results + + +def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: + """Normalize (filename, value, content_type).""" + if len(t) == 2: + return t[0], t[1], "application/octet-stream" + if len(t) == 3: + return t[0], t[1], t[2] + raise ValueError("files tuple must be (filename, file[, content_type])") + + +def _join_url(base_url: str, path: str) -> str: + return urljoin(base_url.rstrip("/") + "/", path.lstrip("/")) + + +def _merge_headers(node_cls: type[IO.ComfyNode], endpoint_headers: dict[str, str]) -> dict[str, str]: + headers = {"Accept": "*/*"} + headers.update(_get_auth_header(node_cls)) + if endpoint_headers: + headers.update(endpoint_headers) + return headers + + +def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]: + params = dict(endpoint_params or {}) + if method.upper() == "GET" and data: + for k, v in data.items(): + if v is not None: + params[k] = v + return params + + +def _friendly_http_message(status: int, body: Any) -> str: + if status == 401: + return "Unauthorized: Please login first to use this node." + if status == 402: + return "Payment Required: Please add credits to your account to use this node." + if status == 409: + return "There is a problem with your account. Please contact support@comfy.org." + if status == 429: + return "Rate Limit Exceeded: Please try again later." + try: + if isinstance(body, dict): + err = body.get("error") + if isinstance(err, dict): + msg = err.get("message") + typ = err.get("type") + if msg and typ: + return f"API Error: {msg} (Type: {typ})" + if msg: + return f"API Error: {msg}" + return f"API Error: {json.dumps(body)}" + else: + txt = str(body) + if len(txt) <= 200: + return f"API Error (raw): {txt}" + return f"API Error (status {status})" + except Exception: + return f"HTTP {status}: Unknown error" + + +def _generate_operation_id(method: str, path: str, attempt: int) -> str: + slug = path.strip("/").replace("/", "_") or "op" + return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" + + +def _snapshot_request_body_for_logging( + content_type: str, + method: str, + data: Optional[dict[str, Any]], + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]], +) -> Optional[Union[dict[str, Any], str]]: + if method.upper() == "GET": + return None + if content_type == "multipart/form-data": + form_fields = sorted([k for k, v in (data or {}).items() if v is not None]) + file_fields: list[dict[str, str]] = [] + if files: + file_iter = files if isinstance(files, list) else list(files.items()) + for field_name, file_obj in file_iter: + if file_obj is None: + continue + if isinstance(file_obj, tuple): + filename = file_obj[0] + else: + filename = getattr(file_obj, "name", field_name) + file_fields.append({"field": field_name, "filename": str(filename or "")}) + return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields} + if content_type == "application/x-www-form-urlencoded": + return data or {} + return data or {} + + +async def _request_base(cfg: _RequestConfig, expect_binary: bool): + """Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors.""" + url = _join_url(_default_base_url(), cfg.endpoint.path) + method = cfg.endpoint.method + params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None) + + async def _monitor(stop_evt: asyncio.Event, start_ts: float): + """Every second: update elapsed time and signal interruption.""" + try: + while not stop_evt.is_set(): + if _is_processing_interrupted(): + return + if cfg.monitor_progress: + _display_time_progress( + cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total + ) + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return # normal shutdown + + start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic() + attempt = 0 + delay = cfg.retry_delay + operation_succeeded: bool = False + final_elapsed_seconds: Optional[int] = None + while True: + attempt += 1 + stop_event = asyncio.Event() + monitor_task: Optional[asyncio.Task] = None + sess: Optional[aiohttp.ClientSession] = None + + operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) + logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) + + payload_headers = _merge_headers(cfg.node_cls, cfg.endpoint.headers) + payload_kw: dict[str, Any] = {"headers": payload_headers} + if method == "GET": + payload_headers.pop("Content-Type", None) + request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files) + try: + if cfg.monitor_progress: + monitor_task = asyncio.create_task(_monitor(stop_event, start_time)) + + timeout = aiohttp.ClientTimeout(total=cfg.timeout) + sess = aiohttp.ClientSession(timeout=timeout) + + if cfg.content_type == "multipart/form-data" and method != "GET": + # aiohttp will set Content-Type boundary; remove any fixed Content-Type + payload_headers.pop("Content-Type", None) + if cfg.multipart_parser and cfg.data: + form = cfg.multipart_parser(cfg.data) + if not isinstance(form, aiohttp.FormData): + raise ValueError("multipart_parser must return aiohttp.FormData") + else: + form = aiohttp.FormData(default_to_multipart=True) + if cfg.data: + for k, v in cfg.data.items(): + if v is None: + continue + form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) + if cfg.files: + file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items() + for field_name, file_obj in file_iter: + if file_obj is None: + continue + if isinstance(file_obj, tuple): + filename, file_value, content_type = _unpack_tuple(file_obj) + else: + filename = getattr(file_obj, "name", field_name) + file_value = file_obj + content_type = "application/octet-stream" + # Attempt to rewind BytesIO for retries + if isinstance(file_value, BytesIO): + with contextlib.suppress(Exception): + file_value.seek(0) + form.add_field(field_name, file_value, filename=filename, content_type=content_type) + payload_kw["data"] = form # do not send body on GET + elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET": + payload_headers["Content-Type"] = "application/x-www-form-urlencoded" + payload_kw["data"] = cfg.data or {} + elif method != "GET": + payload_headers["Content-Type"] = "application/json" + payload_kw["json"] = cfg.data or {} + + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + ) + except Exception as _log_e: + logging.debug("[DEBUG] request logging failed: %s", _log_e) + + # Compose the HTTP request coroutine + req_coro = sess.request(method, url, params=params, **payload_kw) + req_task = asyncio.create_task(req_coro) + + # Race: request vs. monitor (interruption) + tasks = {req_task} + if monitor_task: + tasks.add(monitor_task) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task and monitor_task in done: + # Interrupted – cancel the request and abort + if req_task in pending: + req_task.cancel() + raise ProcessingInterrupted("Task cancelled") + + # Otherwise, request finished + resp = await req_task + async with resp: + if resp.status >= 400: + try: + body = await resp.json() + except (ContentTypeError, json.JSONDecodeError): + body = await resp.text() + # Retryable? + if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries: + logging.warning( + "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).", + method, + url, + resp.status, + delay, + attempt, + cfg.max_retries, + ) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=_friendly_http_message(resp.status, body), + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + + await _sleep_with_interrupt( + delay, + cfg.node_cls, + cfg.wait_label if cfg.monitor_progress else None, + start_time if cfg.monitor_progress else None, + cfg.estimated_total, + display_callback=_display_time_progress if cfg.monitor_progress else None, + ) + delay *= cfg.retry_backoff + continue + msg = _friendly_http_message(resp.status, body) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=msg, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + raise Exception(msg) + + # Success + if expect_binary: + # Read stream in chunks so that cancellation is fast when user interrupts + buff = bytearray() + last_tick = time.monotonic() + async for chunk in resp.content.iter_chunked(64 * 1024): + buff.extend(chunk) + now = time.monotonic() + if now - last_tick >= 1.0: + last_tick = now + if _is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + if cfg.monitor_progress: + _display_time_progress( + cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total + ) + bytes_payload = bytes(buff) + operation_succeeded = True + final_elapsed_seconds = int(time.monotonic() - start_time) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=bytes_payload, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + return bytes_payload + else: + try: + payload = await resp.json() + response_content_to_log: Any = payload + except (ContentTypeError, json.JSONDecodeError): + text = await resp.text() + try: + payload = json.loads(text) if text else {} + except json.JSONDecodeError: + payload = {"_raw": text} + response_content_to_log = payload if isinstance(payload, dict) else text + operation_succeeded = True + final_elapsed_seconds = int(time.monotonic() - start_time) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=response_content_to_log, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + return payload + + except ProcessingInterrupted: + logging.debug("Polling was interrupted by user") + raise + except (ClientError, asyncio.TimeoutError, socket.gaierror) as e: + # Retry transient connection issues + if attempt <= cfg.max_retries: + logging.warning( + "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s", + method, url, delay, attempt, cfg.max_retries, str(e) + ) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + except Exception as _log_e: + logging.debug("[DEBUG] request error logging failed: %s", _log_e) + await _sleep_with_interrupt( + delay, + cfg.node_cls, + cfg.wait_label if cfg.monitor_progress else None, + start_time if cfg.monitor_progress else None, + cfg.estimated_total, + display_callback=_display_time_progress if cfg.monitor_progress else None, + ) + delay *= cfg.retry_backoff + continue + diag = await _diagnose_connectivity() + if diag.get("is_local_issue"): + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"LocalNetworkError: {str(e)}", + ) + except Exception as _log_e: + logging.debug("[DEBUG] final error logging failed: %s", _log_e) + raise LocalNetworkError( + "Unable to connect to the API server due to local network issues. " + "Please check your internet connection and try again." + ) from e + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"ApiServerError: {str(e)}", + ) + except Exception as _log_e: + logging.debug("[DEBUG] final error logging failed: %s", _log_e) + raise ApiServerError( + f"The API server at {_default_base_url()} is currently unreachable. " + f"The service may be experiencing issues." + ) from e + finally: + stop_event.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if sess: + with contextlib.suppress(Exception): + await sess.close() + if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success: + _display_time_progress( + cfg.node_cls, + label=cfg.final_label_on_success, + elapsed_seconds=( + final_elapsed_seconds + if final_elapsed_seconds is not None + else int(time.monotonic() - start_time) + ), + estimated_total=cfg.estimated_total, + price=None, + is_queued=False, + processing_elapsed_seconds=final_elapsed_seconds, + ) + + +def _validate_or_raise(response_model: Type[M], payload: Any) -> M: + try: + return response_model.model_validate(payload) + except Exception as e: + logging.error( + "Response validation failed for %s: %s", + getattr(response_model, "__name__", response_model), + e, + ) + raise Exception( + f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}" + ) from e + + +def _wrap_model_extractor( + response_model: Type[M], + extractor: Optional[Callable[[M], Any]], +) -> Optional[Callable[[dict[str, Any]], Any]]: + """Wrap a typed extractor so it can be used by the dict-based poller. + Validates the dict into `response_model` before invoking `extractor`. + Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating + the same response for multiple extractors in a single poll attempt. + """ + if extractor is None: + return None + _cache: dict[int, M] = {} + + def _wrapped(d: dict[str, Any]) -> Any: + try: + key = id(d) + model = _cache.get(key) + if model is None: + model = response_model.model_validate(d) + _cache[key] = model + return extractor(model) + except Exception as e: + logging.error("Extractor failed (typed -> dict wrapper): %s", e) + raise + + return _wrapped diff --git a/comfy_api_nodes/util/common_exceptions.py b/comfy_api_nodes/util/common_exceptions.py new file mode 100644 index 000000000000..0606a4407007 --- /dev/null +++ b/comfy_api_nodes/util/common_exceptions.py @@ -0,0 +1,14 @@ +class NetworkError(Exception): + """Base exception for network-related errors with diagnostic information.""" + + +class LocalNetworkError(NetworkError): + """Exception raised when local network connectivity issues are detected.""" + + +class ApiServerError(NetworkError): + """Exception raised when the API server is unreachable but internet is working.""" + + +class ProcessingInterrupted(Exception): + """Operation was interrupted by user/runtime via processing_interrupted().""" diff --git a/comfy_api_nodes/util/storage_helpers.py b/comfy_api_nodes/util/storage_helpers.py new file mode 100644 index 000000000000..d8af624efe65 --- /dev/null +++ b/comfy_api_nodes/util/storage_helpers.py @@ -0,0 +1,272 @@ +import uuid +import asyncio +import contextlib +from io import BytesIO +import logging +import time +from typing import Optional, Union + +import aiohttp +import torch +from pydantic import BaseModel, Field + +from comfy_api.latest import IO +from urllib.parse import urlparse +from .api_client import ( + ApiEndpoint, + sync_op_pydantic, + _display_time_progress, + _diagnose_connectivity, +) + +from comfy_api_nodes.apis import request_logger +from comfy_api_nodes.apinode_utils import tensor_to_bytesio +from ._helpers import _sleep_with_interrupt, _is_processing_interrupted +from .common_exceptions import ProcessingInterrupted, LocalNetworkError, ApiServerError + + +class UploadRequest(BaseModel): + file_name: str = Field(..., description="Filename to upload") + content_type: Optional[str] = Field( + None, + description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", + ) + + +class UploadResponse(BaseModel): + download_url: str = Field(..., description="URL to GET uploaded file") + upload_url: str = Field(..., description="URL to PUT file to upload") + + +async def upload_images_to_comfyapi( + cls: type[IO.ComfyNode], + image: torch.Tensor, + *, + max_images: int = 8, + mime_type: Optional[str] = None, + status_update: bool = True, +) -> list[str]: + """ + Uploads images to ComfyUI API and returns download URLs. + To upload multiple images, stack them in the batch dimension first. + """ + # if batch, try to upload each file if max_images is greater than 0 + download_urls: list[str] = [] + is_batch = len(image.shape) > 3 + batch_len = image.shape[0] if is_batch else 1 + + for idx in range(min(batch_len, max_images)): + tensor = image[idx] if is_batch else image + img_io = tensor_to_bytesio(tensor, mime_type=mime_type) + url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, status_update) + download_urls.append(url) + return download_urls + + +async def upload_file_to_comfyapi( + cls: type[IO.ComfyNode], + file_bytes_io: BytesIO, + filename: str, + upload_mime_type: Optional[str], + status_update: bool = True, +) -> str: + """Uploads a single file to ComfyUI API and returns its download URL.""" + if upload_mime_type is None: + request_object = UploadRequest(file_name=filename) + else: + request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) + create_resp = await sync_op_pydantic( + cls, + endpoint=ApiEndpoint(path="/customers/storage", method="POST"), + data=request_object, + response_model=UploadResponse, + final_label_on_success=None, + monitor_progress=False, + ) + await upload_file( + cls, create_resp.upload_url, + file_bytes_io, + content_type=upload_mime_type, + wait_label="Uploading" if status_update else None, + ) + return create_resp.download_url + + +async def upload_file( + cls: type[IO.ComfyNode], + upload_url: str, + file: Union[BytesIO, str], + *, + content_type: Optional[str] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: Optional[str] = None, +) -> None: + """ + Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption. + + Args: + cls: Node class (provides auth context + UI progress hooks). + upload_url: Pre-signed PUT URL. + file: BytesIO or path string. + content_type: Explicit MIME type. If None, we *suppress* Content-Type. + max_retries: Maximum retry attempts. + retry_delay: Initial delay in seconds. + retry_backoff: Exponential backoff factor. + wait_label: Progress label shown in Comfy UI. + + Raises: + ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception + """ + if isinstance(file, BytesIO): + with contextlib.suppress(Exception): + file.seek(0) + data = file.read() + elif isinstance(file, str): + with open(file, "rb") as f: + data = f.read() + else: + raise ValueError("file must be a BytesIO or a filesystem path string") + + headers: dict[str, str] = {} + skip_auto_headers: set[str] = set() + if content_type: + headers["Content-Type"] = content_type + else: + skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request + + attempt = 0 + delay = retry_delay + start_ts = time.monotonic() + op_uuid = uuid.uuid4().hex[:8] + while True: + attempt += 1 + operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid) + timeout = aiohttp.ClientTimeout(total=None) + stop_evt = asyncio.Event() + + async def _monitor(): + try: + while not stop_evt.is_set(): + if _is_processing_interrupted(): + return + if wait_label: + _display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None) + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return + + monitor_task = asyncio.create_task(_monitor()) + sess: Optional[aiohttp.ClientSession] = None + try: + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_params=None, + request_data=f"[File data {len(data)} bytes]", + ) + except Exception as e: + logging.debug("[DEBUG] upload request logging failed: %s", e) + + sess = aiohttp.ClientSession(timeout=timeout) + req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers) + req_task = asyncio.create_task(req) + + done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task in done and req_task in pending: + req_task.cancel() + raise ProcessingInterrupted("Upload cancelled") + + resp = await req_task + async with resp: + if resp.status >= 400: + with contextlib.suppress(Exception): + try: + body = await resp.json() + except Exception: + body = await resp.text() + msg = f"Upload failed with status {resp.status}" + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=msg, + ) + if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries: + await _sleep_with_interrupt( + delay, + cls, + wait_label, + start_ts, + None, + display_callback=_display_time_progress if wait_label else None, + ) + delay *= retry_backoff + continue + raise Exception(f"Failed to upload (HTTP {resp.status}).") + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content="File uploaded successfully.", + ) + except Exception as e: + logging.debug("[DEBUG] upload response logging failed: %s", e) + return + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + if attempt <= max_retries: + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_data=f"[File data {len(data)} bytes]", + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + await _sleep_with_interrupt( + delay, + cls, + wait_label, + start_ts, + None, + display_callback=_display_time_progress if wait_label else None, + ) + delay *= retry_backoff + continue + + diag = await _diagnose_connectivity() + if diag.get("is_local_issue"): + raise LocalNetworkError( + "Unable to connect to the network. Please check your internet connection and try again." + ) from e + raise ApiServerError("The API service appears unreachable at this time.") from e + finally: + stop_evt.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if sess: + with contextlib.suppress(Exception): + await sess.close() + + +def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str: + try: + parsed = urlparse(url) + slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_") + except Exception: + slug = "upload" + return f"{method}_{slug}_{op_uuid}_try{attempt}" diff --git a/pyproject.toml b/pyproject.toml index 653604e24f0e..fcbcb3dd919e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,8 @@ messages_control.disable = [ "too-many-branches", "too-many-locals", "too-many-arguments", + "too-many-return-statements", + "too-many-nested-blocks", "duplicate-code", "abstract-method", "superfluous-parens", From 462ce028beabd5d1184b4be19980a5a995958936 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Tue, 14 Oct 2025 19:49:47 +0300 Subject: [PATCH 2/8] feat(api-nodes): implement new API client for V3 nodes --- comfy_api_nodes/nodes_bytedance.py | 221 ++---- comfy_api_nodes/util/__init__.py | 23 + comfy_api_nodes/util/_helpers.py | 58 ++ comfy_api_nodes/util/api_client.py | 915 ++++++++++++++++++++++ comfy_api_nodes/util/common_exceptions.py | 14 + comfy_api_nodes/util/conversions.py | 25 + comfy_api_nodes/util/download_helpers.py | 246 ++++++ comfy_api_nodes/util/upload_helpers.py | 272 +++++++ pyproject.toml | 2 + 9 files changed, 1614 insertions(+), 162 deletions(-) create mode 100644 comfy_api_nodes/util/_helpers.py create mode 100644 comfy_api_nodes/util/api_client.py create mode 100644 comfy_api_nodes/util/common_exceptions.py create mode 100644 comfy_api_nodes/util/conversions.py create mode 100644 comfy_api_nodes/util/download_helpers.py create mode 100644 comfy_api_nodes/util/upload_helpers.py diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index f3d3f8d3eeab..f2e3e9027b8c 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1,7 +1,7 @@ import logging import math from enum import Enum -from typing import Literal, Optional, Type, Union +from typing import Literal, Optional, Union from typing_extensions import override import torch @@ -13,18 +13,15 @@ get_number_of_images, validate_image_dimensions, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - EmptyRequest, - HttpMethod, - SynchronousOperation, - PollingOperation, - T, + sync_op_pydantic, + poll_op_pydantic, + upload_images_to_comfyapi, ) from comfy_api_nodes.apinode_utils import ( download_url_to_image_tensor, download_url_to_video_output, - upload_images_to_comfyapi, validate_string, image_tensor_pair_to_batch, ) @@ -208,35 +205,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N return None -async def poll_until_finished( - auth_kwargs: dict[str, str], - task_id: str, - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> TaskStatusResponse: - """Polls the ByteDance API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - completed_statuses=[ - "succeeded", - ], - failed_statuses=[ - "cancelled", - "failed", - ], - status_extractor=lambda response: response.status, - auth_kwargs=auth_kwargs, - result_url_extractor=get_video_url_from_task_status, - estimated_duration=estimated_duration, - node_id=node_id, - ).execute() - - class ByteDanceImageNode(IO.ComfyNode): @classmethod @@ -353,20 +321,12 @@ async def execute( guidance_scale=guidance_scale, watermark=watermark, ) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Text2ImageTaskCreationRequest, - response_model=ImageTaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() + response = await sync_op_pydantic( + cls, + endpoint=ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + data=payload, + response_model=ImageTaskCreationResponse, + ) return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) @@ -449,16 +409,7 @@ async def execute( if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") validate_image_aspect_ratio_range(image, (1, 3), (3, 1)) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - source_url = (await upload_images_to_comfyapi( - image, - max_images=1, - mime_type="image/png", - auth_kwargs=auth_kwargs, - ))[0] + source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0] payload = Image2ImageTaskCreationRequest( model=model, prompt=prompt, @@ -467,16 +418,12 @@ async def execute( guidance_scale=guidance_scale, watermark=watermark, ) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Image2ImageTaskCreationRequest, - response_model=ImageTaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() + response = await sync_op_pydantic( + cls, + endpoint=ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + data=payload, + response_model=ImageTaskCreationResponse, + ) return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) @@ -621,41 +568,31 @@ async def execute( raise ValueError( "The maximum number of generated images plus the number of reference images cannot exceed 15." ) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } reference_images_urls = [] if n_input_images: for i in image: validate_image_aspect_ratio_range(i, (1, 3), (3, 1)) - reference_images_urls = (await upload_images_to_comfyapi( + reference_images_urls = await upload_images_to_comfyapi( + cls, image, max_images=n_input_images, mime_type="image/png", - auth_kwargs=auth_kwargs, - )) - payload = Seedream4TaskCreationRequest( - model=model, - prompt=prompt, - image=reference_images_urls, - size=f"{w}x{h}", - seed=seed, - sequential_image_generation=sequential_image_generation, - sequential_image_generation_options=Seedream4Options(max_images=max_images), - watermark=watermark, - ) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Seedream4TaskCreationRequest, - response_model=ImageTaskCreationResponse, + ) + response = await sync_op_pydantic( + cls, + endpoint=ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + response_model=ImageTaskCreationResponse, + data=Seedream4TaskCreationRequest( + model=model, + prompt=prompt, + image=reference_images_urls, + size=f"{w}x{h}", + seed=seed, + sequential_image_generation=sequential_image_generation, + sequential_image_generation_options=Seedream4Options(max_images=max_images), + watermark=watermark, ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - + ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d] @@ -764,19 +701,9 @@ async def execute( f"--camerafixed {str(camera_fixed).lower()} " f"--watermark {str(watermark).lower()}" ) - - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } return await process_video_task( - request_model=Text2VideoTaskCreationRequest, - payload=Text2VideoTaskCreationRequest( - model=model, - content=[TaskTextContent(text=prompt)], - ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, + cls, + payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -879,13 +806,7 @@ async def execute( validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth_kwargs))[0] - + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] prompt = ( f"{prompt} " f"--resolution {resolution} " @@ -897,13 +818,11 @@ async def execute( ) return await process_video_task( - request_model=Image2VideoTaskCreationRequest, + cls, payload=Image2VideoTaskCreationRequest( model=model, content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))], ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -1012,16 +931,11 @@ async def execute( validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - download_urls = await upload_images_to_comfyapi( + cls, image_tensor_pair_to_batch(first_frame, last_frame), max_images=2, mime_type="image/png", - auth_kwargs=auth_kwargs, ) prompt = ( @@ -1035,7 +949,7 @@ async def execute( ) return await process_video_task( - request_model=Image2VideoTaskCreationRequest, + cls, payload=Image2VideoTaskCreationRequest( model=model, content=[ @@ -1044,8 +958,6 @@ async def execute( TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"), ], ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -1141,15 +1053,7 @@ async def execute( validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - image_urls = await upload_images_to_comfyapi( - images, max_images=4, mime_type="image/png", auth_kwargs=auth_kwargs - ) - + image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png") prompt = ( f"{prompt} " f"--resolution {resolution} " @@ -1163,39 +1067,32 @@ async def execute( *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls] ] return await process_video_task( - request_model=Image2VideoTaskCreationRequest, - payload=Image2VideoTaskCreationRequest( - model=model, - content=x, - ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, + cls, + payload=Image2VideoTaskCreationRequest(model=model, content=x), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) async def process_video_task( - request_model: Type[T], + cls: type[IO.ComfyNode], payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], - auth_kwargs: dict, - node_id: str, estimated_duration: Optional[int], ) -> IO.NodeOutput: - initial_response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_TASK_ENDPOINT, - method=HttpMethod.POST, - request_model=request_model, - response_model=TaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - response = await poll_until_finished( - auth_kwargs, - initial_response.id, + initial_response = await sync_op_pydantic( + cls, + endpoint=ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), + data=payload, + response_model=TaskCreationResponse, + ) + response = await poll_op_pydantic( + cls, + poll_endpoint=ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"), + completed_statuses=["succeeded"], + failed_statuses=["cancelled", "failed"], + queued_states=["queued"], + status_extractor=lambda r: r.status, estimated_duration=estimated_duration, - node_id=node_id, + response_model=TaskStatusResponse, ) return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response))) diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index e69de29bb2d1..fe3cda258d6d 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -0,0 +1,23 @@ +from .api_client import ApiEndpoint, sync_op_pydantic, poll_op_pydantic, sync_op, poll_op +from .download_helpers import ( + download_url_to_bytesio, + download_url_to_image_tensor, + bytesio_to_image_tensor, +) +from .upload_helpers import ( + upload_file_to_comfyapi, + upload_images_to_comfyapi, +) + +__all__ = [ + "ApiEndpoint", + "poll_op", + "sync_op", + "poll_op_pydantic", + "sync_op_pydantic", + "upload_file_to_comfyapi", + "upload_images_to_comfyapi", + "download_url_to_bytesio", + "download_url_to_image_tensor", + "bytesio_to_image_tensor", +] diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py new file mode 100644 index 000000000000..1bf951adc48e --- /dev/null +++ b/comfy_api_nodes/util/_helpers.py @@ -0,0 +1,58 @@ +import asyncio +import contextlib +import time +from typing import Optional, Callable + +from comfy_api.latest import IO +from comfy.cli_args import args +from comfy.model_management import processing_interrupted + +from .common_exceptions import ProcessingInterrupted + + +def _is_processing_interrupted() -> bool: + """Return True if user/runtime requested interruption.""" + return processing_interrupted() + + +def _get_node_id(node_cls: type[IO.ComfyNode]) -> str: + return node_cls.hidden.unique_id + + +def _get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]: + if node_cls.hidden.auth_token_comfy_org: + return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"} + if node_cls.hidden.api_key_comfy_org: + return {"X-API-KEY": node_cls.hidden.api_key_comfy_org} + return {} + + +def _default_base_url() -> str: + return getattr(args, "comfy_api_base", "https://api.comfy.org") + + +async def _sleep_with_interrupt( + seconds: float, + node_cls: type[IO.ComfyNode], + label: Optional[str] = None, + start_ts: Optional[float] = None, + estimated_total: Optional[int] = None, + *, + display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None, +): + """ + Sleep in 1s slices while: + - Checking for interruption (raises ProcessingInterrupted). + - Optionally emitting time progress via display_callback (if provided). + """ + end = time.monotonic() + seconds + while True: + if _is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + now = time.monotonic() + if start_ts is not None and label and display_callback: + with contextlib.suppress(Exception): + display_callback(node_cls, label, int(now - start_ts), estimated_total) + if now >= end: + break + await asyncio.sleep(min(1.0, end - now)) diff --git a/comfy_api_nodes/util/api_client.py b/comfy_api_nodes/util/api_client.py new file mode 100644 index 000000000000..95614d6b6987 --- /dev/null +++ b/comfy_api_nodes/util/api_client.py @@ -0,0 +1,915 @@ +import asyncio +import contextlib +import json +import logging +import socket +import time +import uuid +from dataclasses import dataclass +from enum import Enum +from io import BytesIO +from typing import Any, Callable, Optional, Union, Type, TypeVar, Literal + +import aiohttp +from aiohttp.client_exceptions import ClientError, ContentTypeError +from comfy_api.latest import IO +from comfy import utils +from pydantic import BaseModel +from server import PromptServer +from urllib.parse import urljoin, urlparse + +from comfy_api_nodes.apis import request_logger +from .common_exceptions import ProcessingInterrupted, LocalNetworkError, ApiServerError +from ._helpers import ( + _is_processing_interrupted, + _get_node_id, + _get_auth_header, + _default_base_url, + _sleep_with_interrupt, +) + + +M = TypeVar("M", bound=BaseModel) + + +class ApiEndpoint: + def __init__( + self, + path: str, + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET", + *, + query_params: Optional[dict[str, Any]] = None, + headers: Optional[dict[str, str]] = None, + ): + self.path = path + self.method = method + self.query_params = query_params or {} + self.headers = headers or {} + + +@dataclass +class _RequestConfig: + node_cls: type[IO.ComfyNode] + endpoint: ApiEndpoint + timeout: float + content_type: str + data: Optional[dict[str, Any]] + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] + multipart_parser: Optional[Callable] + max_retries: int + retry_delay: float + retry_backoff: float + wait_label: str = "Waiting" + monitor_progress: bool = True + estimated_total: Optional[int] = None + final_label_on_success: Optional[str] = "Completed" + progress_origin_ts: Optional[float] = None + + +@dataclass +class _PollUIState: + started: float + status_label: str = "Queued" + is_queued: bool = True + price: Optional[float] = None + estimated_duration: Optional[int] = None + base_processing_elapsed: float = 0.0 # sum of completed active intervals + active_since: Optional[float] = None # start time of current active interval (None if queued) + + +_RETRY_STATUS = {408, 429, 500, 502, 503, 504} + + +async def sync_op_pydantic( + cls: type[IO.ComfyNode], + endpoint: ApiEndpoint, + *, + response_model: Type[M], + data: Optional[BaseModel] = None, + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + content_type: str = "application/json", + timeout: float = 3600.0, + multipart_parser: Optional[Callable] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: str = "Waiting for server", + estimated_total: Optional[int] = None, + final_label_on_success: Optional[str] = "Completed", + progress_origin_ts: Optional[float] = None, + monitor_progress: bool = True, +) -> M: + raw = await sync_op( + cls, + endpoint, + data=data, + files=files, + content_type=content_type, + timeout=timeout, + multipart_parser=multipart_parser, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff=retry_backoff, + wait_label=wait_label, + estimated_total=estimated_total, + as_binary=False, + final_label_on_success=final_label_on_success, + progress_origin_ts=progress_origin_ts, + monitor_progress=monitor_progress, + ) + if not isinstance(raw, dict): + raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + return _validate_or_raise(response_model, raw) + + +async def poll_op_pydantic( + cls: type[IO.ComfyNode], + *, + poll_endpoint: ApiEndpoint, + response_model: Type[M], + status_extractor: Callable[[M], Optional[str]], + progress_extractor: Optional[Callable[[M], Optional[int]]] = None, + price_extractor: Optional[Callable[[M], Optional[float]]] = None, + completed_statuses: list[str], + failed_statuses: list[str], + queued_states: Optional[list[str]] = None, + poll_interval: float = 5.0, + max_poll_attempts: int = 120, + timeout_per_poll: float = 120.0, + max_retries_per_poll: int = 3, + retry_delay_per_poll: float = 1.0, + retry_backoff_per_poll: float = 2.0, + estimated_duration: Optional[int] = None, + cancel_endpoint: Optional[ApiEndpoint] = None, + cancel_timeout: float = 10.0, +) -> M: + raw = await poll_op( + cls, + poll_endpoint=poll_endpoint, + status_extractor=_wrap_model_extractor(response_model, status_extractor), + progress_extractor=_wrap_model_extractor(response_model, progress_extractor), + price_extractor=_wrap_model_extractor(response_model, price_extractor), + completed_statuses=completed_statuses, + failed_statuses=failed_statuses, + queued_states=queued_states, + poll_interval=poll_interval, + max_poll_attempts=max_poll_attempts, + timeout_per_poll=timeout_per_poll, + max_retries_per_poll=max_retries_per_poll, + retry_delay_per_poll=retry_delay_per_poll, + retry_backoff_per_poll=retry_backoff_per_poll, + estimated_duration=estimated_duration, + cancel_endpoint=cancel_endpoint, + cancel_timeout=cancel_timeout, + ) + if not isinstance(raw, dict): + raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + return _validate_or_raise(response_model, raw) + + +async def sync_op( + cls: type[IO.ComfyNode], + endpoint: ApiEndpoint, + *, + data: Optional[Union[dict[str, Any], BaseModel]] = None, + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, + content_type: str = "application/json", + timeout: float = 3600.0, + multipart_parser: Optional[Callable] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: str = "Waiting for server", + estimated_total: Optional[int] = None, + as_binary: bool = False, + final_label_on_success: Optional[str] = "Completed", + progress_origin_ts: Optional[float] = None, + monitor_progress: bool = True, +) -> Union[dict[str, Any], bytes]: + """ + Make a single network request. + - If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). + - If as_binary=True: returns bytes. + """ + if isinstance(data, BaseModel): + data = data.model_dump(exclude_none=True) + for k, v in list(data.items()): + if isinstance(v, Enum): + data[k] = v.value + cfg = _RequestConfig( + node_cls=cls, + endpoint=endpoint, + timeout=timeout, + content_type=content_type, + data=data, + files=files, + multipart_parser=multipart_parser, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff=retry_backoff, + wait_label=wait_label, + monitor_progress=monitor_progress, + estimated_total=estimated_total, + final_label_on_success=final_label_on_success, + progress_origin_ts=progress_origin_ts, + ) + return await _request_base(cfg, expect_binary=as_binary) + + +async def poll_op( + cls: type[IO.ComfyNode], + *, + poll_endpoint: ApiEndpoint, + status_extractor: Callable[[dict[str, Any]], Optional[str]], + progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None, + price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, + completed_statuses: list[str], + failed_statuses: list[str], + queued_states: Optional[list[str]] = None, + poll_interval: float = 5.0, + max_poll_attempts: int = 120, + timeout_per_poll: float = 120.0, + max_retries_per_poll: int = 3, + retry_delay_per_poll: float = 1.0, + retry_backoff_per_poll: float = 2.0, + estimated_duration: Optional[int] = None, + cancel_endpoint: Optional[ApiEndpoint] = None, + cancel_timeout: float = 10.0, +) -> dict[str, Any]: + """ + Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing, + checks interruption every second, and calls Cancel endpoint (if provided) on interruption. + Returns the final JSON response from the poll endpoint. + """ + queued_states = queued_states or [] + started = time.monotonic() + consumed_attempts = 0 # counts only non-queued polls + + progress_bar = utils.ProgressBar(100) if progress_extractor else None + last_progress: Optional[int] = None + + state = _PollUIState(started=started, estimated_duration=estimated_duration) + stop_ticker = asyncio.Event() + + async def _ticker(): + """Emit a UI update every second while polling is in progress.""" + try: + while not stop_ticker.is_set(): + if _is_processing_interrupted(): + break + now = time.monotonic() + proc_elapsed = state.base_processing_elapsed + ( + (now - state.active_since) if state.active_since is not None else 0.0 + ) + _display_time_progress( + cls, + label=state.status_label, + elapsed_seconds=int(now - state.started), + estimated_total=state.estimated_duration, + price=state.price, + is_queued=state.is_queued, + processing_elapsed_seconds=int(proc_elapsed), + ) + await asyncio.sleep(1.0) + except Exception as exc: + logging.debug("Polling ticker exited: %s", exc) + + ticker_task = asyncio.create_task(_ticker()) + try: + while consumed_attempts < max_poll_attempts: + try: + resp_json = await sync_op( + cls, + poll_endpoint, + timeout=timeout_per_poll, + max_retries=max_retries_per_poll, + retry_delay=retry_delay_per_poll, + retry_backoff=retry_backoff_per_poll, + wait_label="Checking", + estimated_total=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + if not isinstance(resp_json, dict): + raise Exception("Polling endpoint returned non-JSON response.") + except ProcessingInterrupted: + if cancel_endpoint: + with contextlib.suppress(Exception): + await sync_op( + cls, + cancel_endpoint, + timeout=cancel_timeout, + max_retries=0, + wait_label="Cancelling task", + estimated_total=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + raise + + try: + status = status_extractor(resp_json) + except Exception as e: + logging.error("Status extraction failed: %s", e) + status = None + + if price_extractor: + new_price = price_extractor(resp_json) + if new_price is not None: + state.price = new_price + + if progress_extractor: + new_progress = progress_extractor(resp_json) + if new_progress is not None and last_progress != new_progress: + progress_bar.update_absolute(new_progress, total=100) + last_progress = new_progress + + now_ts = time.monotonic() + is_queued = status in queued_states + + if is_queued: + if state.active_since is not None: # If we just moved from active -> queued, close the active interval + state.base_processing_elapsed += (now_ts - state.active_since) + state.active_since = None + else: + if state.active_since is None: # If we just moved from queued -> active, open a new active interval + state.active_since = now_ts + + state.is_queued = is_queued + state.status_label = status or ("Queued" if is_queued else "Processing") + if status in completed_statuses: + if state.active_since is not None: + state.base_processing_elapsed += (now_ts - state.active_since) + state.active_since = None + stop_ticker.set() + with contextlib.suppress(Exception): + await ticker_task + + if progress_bar and last_progress != 100: + progress_bar.update_absolute(100, total=100) + + _display_time_progress( + cls, + label=status if status else "Completed", + elapsed_seconds=int(now_ts - started), + estimated_total=estimated_duration, + price=state.price, + is_queued=False, + processing_elapsed_seconds=int(state.base_processing_elapsed), + ) + return resp_json + + if status in failed_statuses: + msg = f"Task failed: {json.dumps(resp_json)}" + logging.error(msg) + raise Exception(msg) + + try: + await _sleep_with_interrupt(poll_interval, cls, None, None, None) + except ProcessingInterrupted: + if cancel_endpoint: + with contextlib.suppress(Exception): + await sync_op( + cls, + cancel_endpoint, + timeout=cancel_timeout, + max_retries=0, + wait_label="Cancelling task", + estimated_total=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + raise + if not is_queued: + consumed_attempts += 1 + + raise Exception( + f"Polling timed out after {max_poll_attempts} non-queued attempts " + f"(~{int(max_poll_attempts * poll_interval)}s of active polling)." + ) + except ProcessingInterrupted: + raise + except (LocalNetworkError, ApiServerError): + raise + except Exception as e: + raise Exception(f"Polling aborted due to error: {e}") from e + finally: + stop_ticker.set() + with contextlib.suppress(Exception): + await ticker_task + + +def _display_text( + node_cls: type[IO.ComfyNode], + text: Optional[str], + *, + status: Optional[str] = None, + price: Optional[float] = None, +) -> None: + display_lines: list[str] = [] + if status: + display_lines.append(f"Status: {status.capitalize()}") + if price is not None: + display_lines.append(f"Price: ${float(price):,.4f}") + if text is not None: + display_lines.append(text) + if display_lines: + PromptServer.instance.send_progress_text("\n".join(display_lines), _get_node_id(node_cls)) + + +def _display_time_progress( + node_cls: type[IO.ComfyNode], + label: str, + elapsed_seconds: int, + estimated_total: Optional[int] = None, + *, + price: Optional[float] = None, + is_queued: Optional[bool] = None, + processing_elapsed_seconds: Optional[int] = None, +) -> None: + if estimated_total is not None and estimated_total > 0 and is_queued is False: + pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds + remaining = max(0, int(estimated_total) - int(pe)) + time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)" + else: + time_line = f"Time elapsed: {int(elapsed_seconds)}s" + _display_text(node_cls, time_line, status=label, price=price) + + +async def _diagnose_connectivity() -> dict[str, bool]: + """Best-effort connectivity diagnostics to distinguish local vs. server issues.""" + results = { + "internet_accessible": False, + "api_accessible": False, + "is_local_issue": False, + "is_api_issue": False, + } + timeout = aiohttp.ClientTimeout(total=5.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + try: + async with session.get("https://www.google.com") as resp: + results["internet_accessible"] = resp.status < 500 + except (ClientError, asyncio.TimeoutError, socket.gaierror): + results["is_local_issue"] = True + return results + + parsed = urlparse(_default_base_url()) + health_url = f"{parsed.scheme}://{parsed.netloc}/health" + with contextlib.suppress(ClientError, asyncio.TimeoutError): + async with session.get(health_url) as resp: + results["api_accessible"] = resp.status < 500 + results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"] + return results + + +def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: + """Normalize (filename, value, content_type).""" + if len(t) == 2: + return t[0], t[1], "application/octet-stream" + if len(t) == 3: + return t[0], t[1], t[2] + raise ValueError("files tuple must be (filename, file[, content_type])") + + +def _join_url(base_url: str, path: str) -> str: + return urljoin(base_url.rstrip("/") + "/", path.lstrip("/")) + + +def _merge_headers(node_cls: type[IO.ComfyNode], endpoint_headers: dict[str, str]) -> dict[str, str]: + headers = {"Accept": "*/*"} + headers.update(_get_auth_header(node_cls)) + if endpoint_headers: + headers.update(endpoint_headers) + return headers + + +def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]: + params = dict(endpoint_params or {}) + if method.upper() == "GET" and data: + for k, v in data.items(): + if v is not None: + params[k] = v + return params + + +def _friendly_http_message(status: int, body: Any) -> str: + if status == 401: + return "Unauthorized: Please login first to use this node." + if status == 402: + return "Payment Required: Please add credits to your account to use this node." + if status == 409: + return "There is a problem with your account. Please contact support@comfy.org." + if status == 429: + return "Rate Limit Exceeded: Please try again later." + try: + if isinstance(body, dict): + err = body.get("error") + if isinstance(err, dict): + msg = err.get("message") + typ = err.get("type") + if msg and typ: + return f"API Error: {msg} (Type: {typ})" + if msg: + return f"API Error: {msg}" + return f"API Error: {json.dumps(body)}" + else: + txt = str(body) + if len(txt) <= 200: + return f"API Error (raw): {txt}" + return f"API Error (status {status})" + except Exception: + return f"HTTP {status}: Unknown error" + + +def _generate_operation_id(method: str, path: str, attempt: int) -> str: + slug = path.strip("/").replace("/", "_") or "op" + return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" + + +def _snapshot_request_body_for_logging( + content_type: str, + method: str, + data: Optional[dict[str, Any]], + files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]], +) -> Optional[Union[dict[str, Any], str]]: + if method.upper() == "GET": + return None + if content_type == "multipart/form-data": + form_fields = sorted([k for k, v in (data or {}).items() if v is not None]) + file_fields: list[dict[str, str]] = [] + if files: + file_iter = files if isinstance(files, list) else list(files.items()) + for field_name, file_obj in file_iter: + if file_obj is None: + continue + if isinstance(file_obj, tuple): + filename = file_obj[0] + else: + filename = getattr(file_obj, "name", field_name) + file_fields.append({"field": field_name, "filename": str(filename or "")}) + return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields} + if content_type == "application/x-www-form-urlencoded": + return data or {} + return data or {} + + +async def _request_base(cfg: _RequestConfig, expect_binary: bool): + """Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors.""" + url = _join_url(_default_base_url(), cfg.endpoint.path) + method = cfg.endpoint.method + params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None) + + async def _monitor(stop_evt: asyncio.Event, start_ts: float): + """Every second: update elapsed time and signal interruption.""" + try: + while not stop_evt.is_set(): + if _is_processing_interrupted(): + return + if cfg.monitor_progress: + _display_time_progress( + cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total + ) + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return # normal shutdown + + start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic() + attempt = 0 + delay = cfg.retry_delay + operation_succeeded: bool = False + final_elapsed_seconds: Optional[int] = None + while True: + attempt += 1 + stop_event = asyncio.Event() + monitor_task: Optional[asyncio.Task] = None + sess: Optional[aiohttp.ClientSession] = None + + operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) + logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) + + payload_headers = _merge_headers(cfg.node_cls, cfg.endpoint.headers) + payload_kw: dict[str, Any] = {"headers": payload_headers} + if method == "GET": + payload_headers.pop("Content-Type", None) + request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files) + try: + if cfg.monitor_progress: + monitor_task = asyncio.create_task(_monitor(stop_event, start_time)) + + timeout = aiohttp.ClientTimeout(total=cfg.timeout) + sess = aiohttp.ClientSession(timeout=timeout) + + if cfg.content_type == "multipart/form-data" and method != "GET": + # aiohttp will set Content-Type boundary; remove any fixed Content-Type + payload_headers.pop("Content-Type", None) + if cfg.multipart_parser and cfg.data: + form = cfg.multipart_parser(cfg.data) + if not isinstance(form, aiohttp.FormData): + raise ValueError("multipart_parser must return aiohttp.FormData") + else: + form = aiohttp.FormData(default_to_multipart=True) + if cfg.data: + for k, v in cfg.data.items(): + if v is None: + continue + form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) + if cfg.files: + file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items() + for field_name, file_obj in file_iter: + if file_obj is None: + continue + if isinstance(file_obj, tuple): + filename, file_value, content_type = _unpack_tuple(file_obj) + else: + filename = getattr(file_obj, "name", field_name) + file_value = file_obj + content_type = "application/octet-stream" + # Attempt to rewind BytesIO for retries + if isinstance(file_value, BytesIO): + with contextlib.suppress(Exception): + file_value.seek(0) + form.add_field(field_name, file_value, filename=filename, content_type=content_type) + payload_kw["data"] = form # do not send body on GET + elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET": + payload_headers["Content-Type"] = "application/x-www-form-urlencoded" + payload_kw["data"] = cfg.data or {} + elif method != "GET": + payload_headers["Content-Type"] = "application/json" + payload_kw["json"] = cfg.data or {} + + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + ) + except Exception as _log_e: + logging.debug("[DEBUG] request logging failed: %s", _log_e) + + # Compose the HTTP request coroutine + req_coro = sess.request(method, url, params=params, **payload_kw) + req_task = asyncio.create_task(req_coro) + + # Race: request vs. monitor (interruption) + tasks = {req_task} + if monitor_task: + tasks.add(monitor_task) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task and monitor_task in done: + # Interrupted – cancel the request and abort + if req_task in pending: + req_task.cancel() + raise ProcessingInterrupted("Task cancelled") + + # Otherwise, request finished + resp = await req_task + async with resp: + if resp.status >= 400: + try: + body = await resp.json() + except (ContentTypeError, json.JSONDecodeError): + body = await resp.text() + # Retryable? + if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries: + logging.warning( + "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).", + method, + url, + resp.status, + delay, + attempt, + cfg.max_retries, + ) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=_friendly_http_message(resp.status, body), + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + + await _sleep_with_interrupt( + delay, + cfg.node_cls, + cfg.wait_label if cfg.monitor_progress else None, + start_time if cfg.monitor_progress else None, + cfg.estimated_total, + display_callback=_display_time_progress if cfg.monitor_progress else None, + ) + delay *= cfg.retry_backoff + continue + msg = _friendly_http_message(resp.status, body) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=msg, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + raise Exception(msg) + + # Success + if expect_binary: + # Read stream in chunks so that cancellation is fast when user interrupts + buff = bytearray() + last_tick = time.monotonic() + async for chunk in resp.content.iter_chunked(64 * 1024): + buff.extend(chunk) + now = time.monotonic() + if now - last_tick >= 1.0: + last_tick = now + if _is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + if cfg.monitor_progress: + _display_time_progress( + cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total + ) + bytes_payload = bytes(buff) + operation_succeeded = True + final_elapsed_seconds = int(time.monotonic() - start_time) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=bytes_payload, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + return bytes_payload + else: + try: + payload = await resp.json() + response_content_to_log: Any = payload + except (ContentTypeError, json.JSONDecodeError): + text = await resp.text() + try: + payload = json.loads(text) if text else {} + except json.JSONDecodeError: + payload = {"_raw": text} + response_content_to_log = payload if isinstance(payload, dict) else text + operation_succeeded = True + final_elapsed_seconds = int(time.monotonic() - start_time) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=response_content_to_log, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + return payload + + except ProcessingInterrupted: + logging.debug("Polling was interrupted by user") + raise + except (ClientError, asyncio.TimeoutError, socket.gaierror) as e: + # Retry transient connection issues + if attempt <= cfg.max_retries: + logging.warning( + "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s", + method, url, delay, attempt, cfg.max_retries, str(e) + ) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + except Exception as _log_e: + logging.debug("[DEBUG] request error logging failed: %s", _log_e) + await _sleep_with_interrupt( + delay, + cfg.node_cls, + cfg.wait_label if cfg.monitor_progress else None, + start_time if cfg.monitor_progress else None, + cfg.estimated_total, + display_callback=_display_time_progress if cfg.monitor_progress else None, + ) + delay *= cfg.retry_backoff + continue + diag = await _diagnose_connectivity() + if diag.get("is_local_issue"): + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"LocalNetworkError: {str(e)}", + ) + except Exception as _log_e: + logging.debug("[DEBUG] final error logging failed: %s", _log_e) + raise LocalNetworkError( + "Unable to connect to the API server due to local network issues. " + "Please check your internet connection and try again." + ) from e + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"ApiServerError: {str(e)}", + ) + except Exception as _log_e: + logging.debug("[DEBUG] final error logging failed: %s", _log_e) + raise ApiServerError( + f"The API server at {_default_base_url()} is currently unreachable. " + f"The service may be experiencing issues." + ) from e + finally: + stop_event.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if sess: + with contextlib.suppress(Exception): + await sess.close() + if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success: + _display_time_progress( + cfg.node_cls, + label=cfg.final_label_on_success, + elapsed_seconds=( + final_elapsed_seconds + if final_elapsed_seconds is not None + else int(time.monotonic() - start_time) + ), + estimated_total=cfg.estimated_total, + price=None, + is_queued=False, + processing_elapsed_seconds=final_elapsed_seconds, + ) + + +def _validate_or_raise(response_model: Type[M], payload: Any) -> M: + try: + return response_model.model_validate(payload) + except Exception as e: + logging.error( + "Response validation failed for %s: %s", + getattr(response_model, "__name__", response_model), + e, + ) + raise Exception( + f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}" + ) from e + + +def _wrap_model_extractor( + response_model: Type[M], + extractor: Optional[Callable[[M], Any]], +) -> Optional[Callable[[dict[str, Any]], Any]]: + """Wrap a typed extractor so it can be used by the dict-based poller. + Validates the dict into `response_model` before invoking `extractor`. + Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating + the same response for multiple extractors in a single poll attempt. + """ + if extractor is None: + return None + _cache: dict[int, M] = {} + + def _wrapped(d: dict[str, Any]) -> Any: + try: + key = id(d) + model = _cache.get(key) + if model is None: + model = response_model.model_validate(d) + _cache[key] = model + return extractor(model) + except Exception as e: + logging.error("Extractor failed (typed -> dict wrapper): %s", e) + raise + + return _wrapped diff --git a/comfy_api_nodes/util/common_exceptions.py b/comfy_api_nodes/util/common_exceptions.py new file mode 100644 index 000000000000..0606a4407007 --- /dev/null +++ b/comfy_api_nodes/util/common_exceptions.py @@ -0,0 +1,14 @@ +class NetworkError(Exception): + """Base exception for network-related errors with diagnostic information.""" + + +class LocalNetworkError(NetworkError): + """Exception raised when local network connectivity issues are detected.""" + + +class ApiServerError(NetworkError): + """Exception raised when the API server is unreachable but internet is working.""" + + +class ProcessingInterrupted(Exception): + """Operation was interrupted by user/runtime via processing_interrupted().""" diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py new file mode 100644 index 000000000000..207fe0ef25b6 --- /dev/null +++ b/comfy_api_nodes/util/conversions.py @@ -0,0 +1,25 @@ +from io import BytesIO + +import numpy as np +from PIL import Image +import torch + + +def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor: + """Converts image data from BytesIO to a torch.Tensor. + + Args: + image_bytesio: BytesIO object containing the image data. + mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA"). + + Returns: + A torch.Tensor representing the image (1, H, W, C). + + Raises: + PIL.UnidentifiedImageError: If the image data cannot be identified. + ValueError: If the specified mode is invalid. + """ + image = Image.open(image_bytesio) + image = image.convert(mode) + image_array = np.array(image).astype(np.float32) / 255.0 + return torch.from_numpy(image_array).unsqueeze(0) diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py new file mode 100644 index 000000000000..90e127b74433 --- /dev/null +++ b/comfy_api_nodes/util/download_helpers.py @@ -0,0 +1,246 @@ +import asyncio +import contextlib +import logging +import time +import uuid +from io import BytesIO +from typing import Optional, Union, IO +from pathlib import Path + +import aiohttp +import torch +from aiohttp.client_exceptions import ClientError, ContentTypeError +from urllib.parse import urlparse + +from comfy_api_nodes.apis import request_logger + +from ._helpers import _is_processing_interrupted +from .common_exceptions import ProcessingInterrupted, LocalNetworkError, ApiServerError +from .api_client import _diagnose_connectivity +from .conversions import bytesio_to_image_tensor + + +_RETRY_STATUS = {408, 429, 500, 502, 503, 504} + + +async def download_url_to_bytesio( + url: str, + timeout: Optional[float] = None, + *, + dest: Optional[Union[BytesIO, IO[bytes], str, Path]] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, +) -> None: + """Stream-download a URL into memory or to a provided destination. + + Raises: + ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors) + """ + attempt = 0 + delay = retry_delay + + while True: + attempt += 1 + op_id = _generate_operation_id("GET", url, attempt) + timeout_cfg = aiohttp.ClientTimeout(total=timeout) + stop_evt = asyncio.Event() + + async def _monitor(): + try: + while not stop_evt.is_set(): + if _is_processing_interrupted(): + return + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return + + monitor_task: Optional[asyncio.Task] = None + sess: Optional[aiohttp.ClientSession] = None + + # Open file path if a path was provided + is_path_sink = isinstance(dest, (str, Path)) + fhandle = None + try: + try: + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + ) + except Exception as e: + logging.debug("[DEBUG] download request logging failed: %s", e) + + monitor_task = asyncio.create_task(_monitor()) + sess = aiohttp.ClientSession(timeout=timeout_cfg) + req_task = asyncio.create_task(sess.get(url)) + + done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) + + # Interruption wins the race + if monitor_task in done and req_task in pending: + req_task.cancel() + raise ProcessingInterrupted("Task cancelled") + + resp = await req_task + async with resp: + if resp.status >= 400: + # Attempt to capture body for logging (do not log huge binaries) + with contextlib.suppress(Exception): + try: + body = await resp.json() + except (ContentTypeError, ValueError): + text = await resp.text() + body = text if len(text) <= 4096 else f"[text {len(text)} bytes]" + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=f"HTTP {resp.status}", + ) + + if resp.status in _RETRY_STATUS and attempt <= max_retries: + await _sleep_with_cancel(delay) + delay *= retry_backoff + continue + raise Exception(f"Failed to download (HTTP {resp.status}).") + + # Prepare path sink if needed + if is_path_sink: + p = Path(str(dest)) + with contextlib.suppress(Exception): + p.parent.mkdir(parents=True, exist_ok=True) + fhandle = open(p, "wb") + sink = fhandle + else: + sink = dest # BytesIO or file-like + + # Stream body in chunks to sink with cancellation checks + written = 0 + last_tick = time.monotonic() + async for chunk in resp.content.iter_chunked(1024 * 1024): + sink.write(chunk) + written += len(chunk) + now = time.monotonic() + if now - last_tick >= 1.0: + last_tick = now + if _is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + + if isinstance(dest, BytesIO): + dest.seek(0) + + try: + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=f"[streamed {written} bytes to dest]", + ) + except Exception as e: + logging.debug("[DEBUG] download response logging failed: %s", e) + return + except ProcessingInterrupted: + logging.debug("Download was interrupted by user") + raise + except (ClientError, asyncio.TimeoutError) as e: + if attempt <= max_retries: + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + await _sleep_with_cancel(delay) + delay *= retry_backoff + continue + + diag = await _diagnose_connectivity() + if diag.get("is_local_issue"): + raise LocalNetworkError( + "Unable to connect to the network. Please check your internet connection and try again." + ) from e + raise ApiServerError("The remote service appears unreachable at this time.") from e + finally: + with contextlib.suppress(Exception): + if fhandle: + fhandle.flush() + fhandle.close() + stop_evt.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if sess: + with contextlib.suppress(Exception): + await sess.close() + + +async def download_url_to_image_tensor( + url: str, + timeout: int = None, + auth_kwargs: Optional[dict[str, str]] = None, + *, + dest: Optional[Union[BytesIO, IO[bytes], str, Path]] = None, + mode: str = "RGBA", +) -> torch.Tensor: + """ + Download image and decode to tensor. Supports streaming `dest` like util version. + """ + if dest is None: + bio = await download_url_to_bytesio(url, timeout, auth_kwargs, dest=None) + return bytesio_to_image_tensor(bio, mode=mode) # type: ignore[arg-type] + + await download_url_to_bytesio(url, timeout, auth_kwargs, dest=dest) + + if isinstance(dest, BytesIO): + with contextlib.suppress(Exception): + dest.seek(0) + return bytesio_to_image_tensor(dest, mode=mode) + + if hasattr(dest, "read") and hasattr(dest, "seek"): + try: + with contextlib.suppress(Exception): + dest.flush() + dest.seek(0) + data = dest.read() + return bytesio_to_image_tensor(BytesIO(data), mode=mode) + except Exception: + pass + + if isinstance(dest, (str, Path)) or getattr(dest, "name", None): + path_str = str(dest if isinstance(dest, (str, Path)) else getattr(dest, "name")) + with open(path_str, "rb") as f: + return bytesio_to_image_tensor(BytesIO(f.read()), mode=mode) + + raise ValueError( + "Destination is not readable and no path is available to decode the image. " + "Pass dest=None to decode from memory, or provide a readable handle / path." + ) + + +def _generate_operation_id(method: str, url: str, attempt: int) -> str: + try: + parsed = urlparse(url) + slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_") + except Exception: + slug = "download" + return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" + + +async def _sleep_with_cancel(seconds: float) -> None: + """Sleep in 1s slices while checking for interruption.""" + end = time.monotonic() + seconds + while True: + if _is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + now = time.monotonic() + if now >= end: + return + await asyncio.sleep(min(1.0, end - now)) diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py new file mode 100644 index 000000000000..457a6f11ecf4 --- /dev/null +++ b/comfy_api_nodes/util/upload_helpers.py @@ -0,0 +1,272 @@ +import uuid +import asyncio +import contextlib +from io import BytesIO +import logging +import time +from typing import Optional, Union + +import aiohttp +import torch +from pydantic import BaseModel, Field + +from comfy_api.latest import IO +from urllib.parse import urlparse +from .api_client import ( + ApiEndpoint, + sync_op_pydantic, + _display_time_progress, + _diagnose_connectivity, +) + +from comfy_api_nodes.apis import request_logger +from comfy_api_nodes.apinode_utils import tensor_to_bytesio +from ._helpers import _sleep_with_interrupt, _is_processing_interrupted +from .common_exceptions import ProcessingInterrupted, LocalNetworkError, ApiServerError + + +class UploadRequest(BaseModel): + file_name: str = Field(..., description="Filename to upload") + content_type: Optional[str] = Field( + None, + description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", + ) + + +class UploadResponse(BaseModel): + download_url: str = Field(..., description="URL to GET uploaded file") + upload_url: str = Field(..., description="URL to PUT file to upload") + + +async def upload_images_to_comfyapi( + cls: type[IO.ComfyNode], + image: torch.Tensor, + *, + max_images: int = 8, + mime_type: Optional[str] = None, + wait_label: Optional[str] = "Uploading", +) -> list[str]: + """ + Uploads images to ComfyUI API and returns download URLs. + To upload multiple images, stack them in the batch dimension first. + """ + # if batch, try to upload each file if max_images is greater than 0 + download_urls: list[str] = [] + is_batch = len(image.shape) > 3 + batch_len = image.shape[0] if is_batch else 1 + + for idx in range(min(batch_len, max_images)): + tensor = image[idx] if is_batch else image + img_io = tensor_to_bytesio(tensor, mime_type=mime_type) + url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label) + download_urls.append(url) + return download_urls + + +async def upload_file_to_comfyapi( + cls: type[IO.ComfyNode], + file_bytes_io: BytesIO, + filename: str, + upload_mime_type: Optional[str], + wait_label: Optional[str] = "Uploading", +) -> str: + """Uploads a single file to ComfyUI API and returns its download URL.""" + if upload_mime_type is None: + request_object = UploadRequest(file_name=filename) + else: + request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) + create_resp = await sync_op_pydantic( + cls, + endpoint=ApiEndpoint(path="/customers/storage", method="POST"), + data=request_object, + response_model=UploadResponse, + final_label_on_success=None, + monitor_progress=False, + ) + await upload_file( + cls, create_resp.upload_url, + file_bytes_io, + content_type=upload_mime_type, + wait_label=wait_label, + ) + return create_resp.download_url + + +async def upload_file( + cls: type[IO.ComfyNode], + upload_url: str, + file: Union[BytesIO, str], + *, + content_type: Optional[str] = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: Optional[str] = None, +) -> None: + """ + Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption. + + Args: + cls: Node class (provides auth context + UI progress hooks). + upload_url: Pre-signed PUT URL. + file: BytesIO or path string. + content_type: Explicit MIME type. If None, we *suppress* Content-Type. + max_retries: Maximum retry attempts. + retry_delay: Initial delay in seconds. + retry_backoff: Exponential backoff factor. + wait_label: Progress label shown in Comfy UI. + + Raises: + ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception + """ + if isinstance(file, BytesIO): + with contextlib.suppress(Exception): + file.seek(0) + data = file.read() + elif isinstance(file, str): + with open(file, "rb") as f: + data = f.read() + else: + raise ValueError("file must be a BytesIO or a filesystem path string") + + headers: dict[str, str] = {} + skip_auto_headers: set[str] = set() + if content_type: + headers["Content-Type"] = content_type + else: + skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request + + attempt = 0 + delay = retry_delay + start_ts = time.monotonic() + op_uuid = uuid.uuid4().hex[:8] + while True: + attempt += 1 + operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid) + timeout = aiohttp.ClientTimeout(total=None) + stop_evt = asyncio.Event() + + async def _monitor(): + try: + while not stop_evt.is_set(): + if _is_processing_interrupted(): + return + if wait_label: + _display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None) + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return + + monitor_task = asyncio.create_task(_monitor()) + sess: Optional[aiohttp.ClientSession] = None + try: + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_params=None, + request_data=f"[File data {len(data)} bytes]", + ) + except Exception as e: + logging.debug("[DEBUG] upload request logging failed: %s", e) + + sess = aiohttp.ClientSession(timeout=timeout) + req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers) + req_task = asyncio.create_task(req) + + done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task in done and req_task in pending: + req_task.cancel() + raise ProcessingInterrupted("Upload cancelled") + + resp = await req_task + async with resp: + if resp.status >= 400: + with contextlib.suppress(Exception): + try: + body = await resp.json() + except Exception: + body = await resp.text() + msg = f"Upload failed with status {resp.status}" + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=msg, + ) + if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries: + await _sleep_with_interrupt( + delay, + cls, + wait_label, + start_ts, + None, + display_callback=_display_time_progress if wait_label else None, + ) + delay *= retry_backoff + continue + raise Exception(f"Failed to upload (HTTP {resp.status}).") + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content="File uploaded successfully.", + ) + except Exception as e: + logging.debug("[DEBUG] upload response logging failed: %s", e) + return + except (aiohttp.ClientError, asyncio.TimeoutError) as e: + if attempt <= max_retries: + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_data=f"[File data {len(data)} bytes]", + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + await _sleep_with_interrupt( + delay, + cls, + wait_label, + start_ts, + None, + display_callback=_display_time_progress if wait_label else None, + ) + delay *= retry_backoff + continue + + diag = await _diagnose_connectivity() + if diag.get("is_local_issue"): + raise LocalNetworkError( + "Unable to connect to the network. Please check your internet connection and try again." + ) from e + raise ApiServerError("The API service appears unreachable at this time.") from e + finally: + stop_evt.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if sess: + with contextlib.suppress(Exception): + await sess.close() + + +def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str: + try: + parsed = urlparse(url) + slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_") + except Exception: + slug = "upload" + return f"{method}_{slug}_{op_uuid}_try{attempt}" diff --git a/pyproject.toml b/pyproject.toml index 653604e24f0e..fcbcb3dd919e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,8 @@ messages_control.disable = [ "too-many-branches", "too-many-locals", "too-many-arguments", + "too-many-return-statements", + "too-many-nested-blocks", "duplicate-code", "abstract-method", "superfluous-parens", From 01d9369018d75d3f6b5708230f9a7abf5136d484 Mon Sep 17 00:00:00 2001 From: Alexander Piskun Date: Thu, 16 Oct 2025 23:14:55 +0300 Subject: [PATCH 3/8] feat(api-nodes): implement new API client for V3 nodes --- comfy_api_nodes/apinode_utils.py | 435 +----------------- comfy_api_nodes/nodes_bfl.py | 3 +- comfy_api_nodes/nodes_bytedance.py | 90 ++-- comfy_api_nodes/nodes_gemini.py | 5 +- comfy_api_nodes/nodes_kling.py | 350 ++++---------- comfy_api_nodes/nodes_luma.py | 2 +- comfy_api_nodes/nodes_minimax.py | 2 +- comfy_api_nodes/nodes_moonvalley.py | 355 +++----------- comfy_api_nodes/nodes_openai.py | 4 +- comfy_api_nodes/nodes_pika.py | 6 +- comfy_api_nodes/nodes_pixverse.py | 13 +- comfy_api_nodes/nodes_recraft.py | 4 +- comfy_api_nodes/nodes_runway.py | 156 ++----- comfy_api_nodes/nodes_sora.py | 74 +-- comfy_api_nodes/nodes_stability.py | 8 +- comfy_api_nodes/nodes_veo2.py | 5 +- comfy_api_nodes/nodes_vidu.py | 114 ++--- comfy_api_nodes/nodes_wan.py | 8 +- comfy_api_nodes/util/__init__.py | 72 ++- comfy_api_nodes/util/_helpers.py | 29 +- .../util/{api_client.py => client.py} | 107 +++-- comfy_api_nodes/util/conversions.py | 386 +++++++++++++++- comfy_api_nodes/util/download_helpers.py | 236 ++++------ comfy_api_nodes/util/storage_helpers.py | 272 ----------- comfy_api_nodes/util/upload_helpers.py | 94 +++- comfy_api_nodes/util/validation_utils.py | 58 ++- pyproject.toml | 9 + 27 files changed, 1056 insertions(+), 1841 deletions(-) rename comfy_api_nodes/util/{api_client.py => client.py} (93%) delete mode 100644 comfy_api_nodes/util/storage_helpers.py diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index bc3d2d07e6a5..e3d2820592cd 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -1,15 +1,10 @@ from __future__ import annotations import aiohttp -import io -import logging import mimetypes -import os from typing import Optional, Union from comfy.utils import common_upscale -from comfy_api.input_impl import VideoFromFile from comfy_api.util import VideoContainer, VideoCodec from comfy_api.input.video_types import VideoInput -from comfy_api.input.basic_types import AudioInput from comfy_api_nodes.apis.client import ( ApiClient, ApiEndpoint, @@ -26,43 +21,8 @@ import torch import math import base64 -import uuid +from .util import tensor_to_bytesio, bytesio_to_image_tensor from io import BytesIO -import av - - -async def download_url_to_video_output( - video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None -) -> VideoFromFile: - """Downloads a video from a URL and returns a `VIDEO` output. - - Args: - video_url: The URL of the video to download. - - Returns: - A Comfy node `VIDEO` output. - """ - video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs) - if video_io is None: - error_msg = f"Failed to download video from {video_url}" - logging.error(error_msg) - raise ValueError(error_msg) - return VideoFromFile(video_io) - - -def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: - """Downscale input image tensor to roughly the specified total pixels.""" - samples = image.movedim(-1, 1) - total = int(total_pixels) - scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) - if scale_by >= 1: - return image - width = round(samples.shape[3] * scale_by) - height = round(samples.shape[2] * scale_by) - - s = common_upscale(samples, width, height, "lanczos", "disabled") - s = s.movedim(1, -1) - return s async def validate_and_cast_response( @@ -162,11 +122,6 @@ def validate_aspect_ratio( return aspect_ratio -def mimetype_to_extension(mime_type: str) -> str: - """Converts a MIME type to a file extension.""" - return mime_type.split("/")[-1].lower() - - async def download_url_to_bytesio( url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None ) -> BytesIO: @@ -195,136 +150,11 @@ async def download_url_to_bytesio( return BytesIO(await resp.read()) -def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor: - """Converts image data from BytesIO to a torch.Tensor. - - Args: - image_bytesio: BytesIO object containing the image data. - mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA"). - - Returns: - A torch.Tensor representing the image (1, H, W, C). - - Raises: - PIL.UnidentifiedImageError: If the image data cannot be identified. - ValueError: If the specified mode is invalid. - """ - image = Image.open(image_bytesio) - image = image.convert(mode) - image_array = np.array(image).astype(np.float32) / 255.0 - return torch.from_numpy(image_array).unsqueeze(0) - - -async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor: - """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" - image_bytesio = await download_url_to_bytesio(url, timeout) - return bytesio_to_image_tensor(image_bytesio) - - def process_image_response(response_content: bytes | str) -> torch.Tensor: """Uses content from a Response object and converts it to a torch.Tensor""" return bytesio_to_image_tensor(BytesIO(response_content)) -def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: - """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling.""" - if len(image.shape) > 3: - image = image[0] - # TODO: remove alpha if not allowed and present - input_tensor = image.cpu() - input_tensor = downscale_image_tensor( - input_tensor.unsqueeze(0), total_pixels=total_pixels - ).squeeze() - image_np = (input_tensor.numpy() * 255).astype(np.uint8) - img = Image.fromarray(image_np) - return img - - -def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: - """Converts a PIL Image to a BytesIO object.""" - if not mime_type: - mime_type = "image/png" - - img_byte_arr = io.BytesIO() - # Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG') - pil_format = mime_type.split("/")[-1].upper() - if pil_format == "JPG": - pil_format = "JPEG" - img.save(img_byte_arr, format=pil_format) - img_byte_arr.seek(0) - return img_byte_arr - - -def tensor_to_bytesio( - image: torch.Tensor, - name: Optional[str] = None, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> BytesIO: - """Converts a torch.Tensor image to a named BytesIO object. - - Args: - image: Input torch.Tensor image. - name: Optional filename for the BytesIO object. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). - - Returns: - Named BytesIO object containing the image data, with pointer set to the start of buffer. - """ - if not mime_type: - mime_type = "image/png" - - pil_image = _tensor_to_pil(image, total_pixels=total_pixels) - img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type) - img_binary.name = ( - f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" - ) - return img_binary - - -def tensor_to_base64_string( - image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> str: - """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. - - Args: - image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). - - Returns: - Base64 encoded string of the image. - """ - pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels) - img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type) - img_bytes = img_byte_arr.getvalue() - # Encode bytes to base64 string - base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8") - return base64_encoded_string - - -def tensor_to_data_uri( - image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> str: - """Converts a tensor image to a Data URI string. - - Args: - image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp'). - - Returns: - Data URI string (e.g., 'data:image/png;base64,...'). - """ - base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type) - return f"data:{mime_type};base64,{base64_string}" - - def text_filepath_to_base64_string(filepath: str) -> str: """Converts a text file to a base64 string.""" with open(filepath, "rb") as f: @@ -392,7 +222,7 @@ def video_to_base64_string( container_format: Optional container format to use (defaults to video.container if available) codec: Optional codec to use (defaults to video.codec if available) """ - video_bytes_io = io.BytesIO() + video_bytes_io = BytesIO() # Use provided format/codec if specified, otherwise use video's own if available format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) @@ -403,214 +233,6 @@ def video_to_base64_string( return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") -async def upload_video_to_comfyapi( - video: VideoInput, - auth_kwargs: Optional[dict[str, str]] = None, - container: VideoContainer = VideoContainer.MP4, - codec: VideoCodec = VideoCodec.H264, - max_duration: Optional[int] = None, -) -> str: - """ - Uploads a single video to ComfyUI API and returns its download URL. - Uses the specified container and codec for saving the video before upload. - - Args: - video: VideoInput object (Comfy VIDEO type). - auth_kwargs: Optional authentication token(s). - container: The video container format to use (default: MP4). - codec: The video codec to use (default: H264). - max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised. - - Returns: - The download URL for the uploaded video file. - """ - if max_duration is not None: - try: - actual_duration = video.duration_seconds - if actual_duration is not None and actual_duration > max_duration: - raise ValueError( - f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)." - ) - except Exception as e: - logging.error("Error getting video duration: %s", str(e)) - raise ValueError(f"Could not verify video duration from source: {e}") from e - - upload_mime_type = f"video/{container.value.lower()}" - filename = f"uploaded_video.{container.value.lower()}" - - # Convert VideoInput to BytesIO using specified container/codec - video_bytes_io = io.BytesIO() - video.save_to(video_bytes_io, format=container, codec=codec) - video_bytes_io.seek(0) - - return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs) - - -def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray: - """ - Prepares audio waveform for av library by converting to a contiguous numpy array. - - Args: - waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type. - - Returns: - Contiguous numpy array of the audio waveform. If the audio was batched, - the first item is taken. - """ - if waveform.ndim != 3 or waveform.shape[0] != 1: - raise ValueError("Expected waveform tensor shape (1, channels, samples)") - - # If batch is > 1, take first item - if waveform.shape[0] > 1: - waveform = waveform[0] - - # Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array - audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy() - if audio_data_np.dtype != np.float32: - audio_data_np = audio_data_np.astype(np.float32) - - return audio_data_np - - -def audio_ndarray_to_bytesio( - audio_data_np: np.ndarray, - sample_rate: int, - container_format: str = "mp4", - codec_name: str = "aac", -) -> BytesIO: - """ - Encodes a numpy array of audio data into a BytesIO object. - """ - audio_bytes_io = io.BytesIO() - with av.open(audio_bytes_io, mode="w", format=container_format) as output_container: - audio_stream = output_container.add_stream(codec_name, rate=sample_rate) - frame = av.AudioFrame.from_ndarray( - audio_data_np, - format="fltp", - layout="stereo" if audio_data_np.shape[0] > 1 else "mono", - ) - frame.sample_rate = sample_rate - frame.pts = 0 - - for packet in audio_stream.encode(frame): - output_container.mux(packet) - - # Flush stream - for packet in audio_stream.encode(None): - output_container.mux(packet) - - audio_bytes_io.seek(0) - return audio_bytes_io - - -async def upload_audio_to_comfyapi( - audio: AudioInput, - auth_kwargs: Optional[dict[str, str]] = None, - container_format: str = "mp4", - codec_name: str = "aac", - mime_type: str = "audio/mp4", - filename: str = "uploaded_audio.mp4", -) -> str: - """ - Uploads a single audio input to ComfyUI API and returns its download URL. - Encodes the raw waveform into the specified format before uploading. - - Args: - audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate) - auth_kwargs: Optional authentication token(s). - - Returns: - The download URL for the uploaded audio file. - """ - sample_rate: int = audio["sample_rate"] - waveform: torch.Tensor = audio["waveform"] - audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) - audio_bytes_io = audio_ndarray_to_bytesio( - audio_data_np, sample_rate, container_format, codec_name - ) - - return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) - - -def f32_pcm(wav: torch.Tensor) -> torch.Tensor: - """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" - if wav.dtype.is_floating_point: - return wav - elif wav.dtype == torch.int16: - return wav.float() / (2 ** 15) - elif wav.dtype == torch.int32: - return wav.float() / (2 ** 31) - raise ValueError(f"Unsupported wav dtype: {wav.dtype}") - - -def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict: - """ - Decode any common audio container from bytes using PyAV and return - a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}. - """ - with av.open(io.BytesIO(audio_bytes)) as af: - if not af.streams.audio: - raise ValueError("No audio stream found in response.") - stream = af.streams.audio[0] - - in_sr = int(stream.codec_context.sample_rate) - out_sr = in_sr - - frames: list[torch.Tensor] = [] - n_channels = stream.channels or 1 - - for frame in af.decode(streams=stream.index): - arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T] - buf = torch.from_numpy(arr) - if buf.ndim == 1: - buf = buf.unsqueeze(0) # [T] -> [1, T] - elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels: - buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T] - elif buf.shape[0] != n_channels: - buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T] - frames.append(buf) - - if not frames: - raise ValueError("Decoded zero audio frames.") - - wav = torch.cat(frames, dim=1) # [C, T] - wav = f32_pcm(wav) - return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} - - -def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO: - waveform = audio["waveform"].cpu() - - output_buffer = io.BytesIO() - output_container = av.open(output_buffer, mode='w', format="mp3") - - out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"]) - out_stream.bit_rate = 320000 - - frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo') - frame.sample_rate = audio["sample_rate"] - frame.pts = 0 - output_container.mux(out_stream.encode(frame)) - output_container.mux(out_stream.encode(None)) - output_container.close() - output_buffer.seek(0) - return output_buffer - - -def audio_to_base64_string( - audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac" -) -> str: - """Converts an audio input to a base64 string.""" - sample_rate: int = audio["sample_rate"] - waveform: torch.Tensor = audio["waveform"] - audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) - audio_bytes_io = audio_ndarray_to_bytesio( - audio_data_np, sample_rate, container_format, codec_name - ) - audio_bytes = audio_bytes_io.getvalue() - return base64.b64encode(audio_bytes).decode("utf-8") - - async def upload_images_to_comfyapi( image: torch.Tensor, max_images=8, @@ -663,56 +285,3 @@ def resize_mask_to_image( if not allow_gradient: mask = (mask > 0.5).float() return mask - - -def validate_string( - string: str, - strip_whitespace=True, - field_name="prompt", - min_length=None, - max_length=None, -): - if string is None: - raise Exception(f"Field '{field_name}' cannot be empty.") - if strip_whitespace: - string = string.strip() - if min_length and len(string) < min_length: - raise Exception( - f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long." - ) - if max_length and len(string) > max_length: - raise Exception( - f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long." - ) - - -def image_tensor_pair_to_batch( - image1: torch.Tensor, image2: torch.Tensor -) -> torch.Tensor: - """ - Converts a pair of image tensors to a batch tensor. - If the images are not the same size, the smaller image is resized to - match the larger image. - """ - if image1.shape[1:] != image2.shape[1:]: - image2 = common_upscale( - image2.movedim(-1, 1), - image1.shape[2], - image1.shape[1], - "bilinear", - "center", - ).movedim(1, -1) - return torch.cat((image1, image2), dim=0) - - -def get_size(path_or_object: Union[str, io.BytesIO]) -> int: - if isinstance(path_or_object, str): - return os.path.getsize(path_or_object) - return len(path_or_object.getvalue()) - - -def validate_container_format_is_mp4(video: VideoInput) -> None: - """Validates video container format is MP4.""" - container_format = video.get_container_format() - if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: - raise ValueError(f"Only MP4 container format supported. Got: {container_format}") diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index b6cc90f05c7a..3e83eb127df7 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -21,12 +21,11 @@ SynchronousOperation, ) from comfy_api_nodes.apinode_utils import ( - downscale_image_tensor, validate_aspect_ratio, process_image_response, resize_mask_to_image, - validate_string, ) +from comfy_api_nodes.util import validate_string, downscale_image_tensor import numpy as np from PIL import Image diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index f2e3e9027b8c..534af380debb 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -2,31 +2,26 @@ import math from enum import Enum from typing import Literal, Optional, Union -from typing_extensions import override import torch from pydantic import BaseModel, Field +from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.util.validation_utils import ( - validate_image_aspect_ratio_range, - get_number_of_images, - validate_image_dimensions, -) +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.util import ( ApiEndpoint, - sync_op_pydantic, - poll_op_pydantic, - upload_images_to_comfyapi, -) -from comfy_api_nodes.apinode_utils import ( download_url_to_image_tensor, download_url_to_video_output, - validate_string, + get_number_of_images, image_tensor_pair_to_batch, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_image_aspect_ratio_range, + validate_image_dimensions, + validate_string, ) - BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" # Long-running tasks endpoints(e.g., video) @@ -43,13 +38,14 @@ class Image2ImageModelName(str, Enum): class Text2VideoModelName(str, Enum): - seedance_1_pro = "seedance-1-0-pro-250528" + seedance_1_pro = "seedance-1-0-pro-250528" seedance_1_lite = "seedance-1-0-lite-t2v-250428" class Image2VideoModelName(str, Enum): """note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757""" - seedance_1_pro = "seedance-1-0-pro-250528" + + seedance_1_pro = "seedance-1-0-pro-250528" seedance_1_lite = "seedance-1-0-lite-i2v-250428" @@ -271,7 +267,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the image", + tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), ], @@ -309,8 +305,7 @@ async def execute( w, h = width, height if not (512 <= w <= 2048) or not (512 <= h <= 2048): raise ValueError( - f"Custom size out of range: {w}x{h}. " - "Both width and height must be between 512 and 2048 pixels." + f"Custom size out of range: {w}x{h}. " "Both width and height must be between 512 and 2048 pixels." ) payload = Text2ImageTaskCreationRequest( @@ -321,9 +316,9 @@ async def execute( guidance_scale=guidance_scale, watermark=watermark, ) - response = await sync_op_pydantic( + response = await sync_op( cls, - endpoint=ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), data=payload, response_model=ImageTaskCreationResponse, ) @@ -380,7 +375,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the image", + tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), ], @@ -418,9 +413,9 @@ async def execute( guidance_scale=guidance_scale, watermark=watermark, ) - response = await sync_op_pydantic( + response = await sync_op( cls, - endpoint=ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), data=payload, response_model=ImageTaskCreationResponse, ) @@ -451,7 +446,7 @@ def define_schema(cls): IO.Image.Input( "image", tooltip="Input image(s) for image-to-image generation. " - "List of 1-10 images for single or multi-reference generation.", + "List of 1-10 images for single or multi-reference generation.", optional=True, ), IO.Combo.Input( @@ -481,9 +476,9 @@ def define_schema(cls): "sequential_image_generation", options=["disabled", "auto"], tooltip="Group image generation mode. " - "'disabled' generates a single image. " - "'auto' lets the model decide whether to generate multiple related images " - "(e.g., story scenes, character variations).", + "'disabled' generates a single image. " + "'auto' lets the model decide whether to generate multiple related images " + "(e.g., story scenes, character variations).", optional=True, ), IO.Int.Input( @@ -494,7 +489,7 @@ def define_schema(cls): step=1, display_mode=IO.NumberDisplay.number, tooltip="Maximum number of images to generate when sequential_image_generation='auto'. " - "Total images (input + generated) cannot exceed 15.", + "Total images (input + generated) cannot exceed 15.", optional=True, ), IO.Int.Input( @@ -511,7 +506,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the image.", + tooltip='Whether to add an "AI generated" watermark to the image.', optional=True, ), IO.Boolean.Input( @@ -558,8 +553,7 @@ async def execute( w, h = width, height if not (1024 <= w <= 4096) or not (1024 <= h <= 4096): raise ValueError( - f"Custom size out of range: {w}x{h}. " - "Both width and height must be between 1024 and 4096 pixels." + f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels." ) n_input_images = get_number_of_images(image) if image is not None else 0 if n_input_images > 10: @@ -578,9 +572,9 @@ async def execute( max_images=n_input_images, mime_type="image/png", ) - response = await sync_op_pydantic( + response = await sync_op( cls, - endpoint=ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), response_model=ImageTaskCreationResponse, data=Seedream4TaskCreationRequest( model=model, @@ -656,13 +650,13 @@ def define_schema(cls): "camera_fixed", default=False, tooltip="Specifies whether to fix the camera. The platform appends an instruction " - "to fix the camera to your prompt, but does not guarantee the actual effect.", + "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, ), IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -767,13 +761,13 @@ def define_schema(cls): "camera_fixed", default=False, tooltip="Specifies whether to fix the camera. The platform appends an instruction " - "to fix the camera to your prompt, but does not guarantee the actual effect.", + "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, ), IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -890,13 +884,13 @@ def define_schema(cls): "camera_fixed", default=False, tooltip="Specifies whether to fix the camera. The platform appends an instruction " - "to fix the camera to your prompt, but does not guarantee the actual effect.", + "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, ), IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -1020,7 +1014,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -1064,7 +1058,7 @@ async def execute( ) x = [ TaskTextContent(text=prompt), - *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls] + *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls], ] return await process_video_task( cls, @@ -1078,18 +1072,15 @@ async def process_video_task( payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], estimated_duration: Optional[int], ) -> IO.NodeOutput: - initial_response = await sync_op_pydantic( + initial_response = await sync_op( cls, - endpoint=ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), + ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), data=payload, response_model=TaskCreationResponse, ) - response = await poll_op_pydantic( + response = await poll_op( cls, - poll_endpoint=ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"), - completed_statuses=["succeeded"], - failed_statuses=["cancelled", "failed"], - queued_states=["queued"], + ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"), status_extractor=lambda r: r.status, estimated_duration=estimated_duration, response_model=TaskStatusResponse, @@ -1118,5 +1109,6 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: ByteDanceImageReferenceNode, ] + async def comfy_entrypoint() -> ByteDanceExtension: return ByteDanceExtension() diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index c1941cbe929f..ca11b67ed192 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -33,12 +33,9 @@ SynchronousOperation, ) from comfy_api_nodes.apinode_utils import ( - validate_string, - audio_to_base64_string, video_to_base64_string, - tensor_to_base64_string, - bytesio_to_image_tensor, ) +from comfy_api_nodes.util import validate_string, tensor_to_base64_string, bytesio_to_image_tensor, audio_to_base64_string from comfy_api.util import VideoContainer, VideoCodec diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 67c8307c55ff..eea65c9acf97 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -5,8 +5,7 @@ """ from __future__ import annotations -from typing import Optional, TypeVar, Any -from collections.abc import Callable +from typing import Optional, TypeVar import math import logging @@ -15,7 +14,6 @@ import torch from comfy_api_nodes.apis import ( - KlingTaskStatus, KlingCameraControl, KlingCameraConfig, KlingCameraControlType, @@ -52,26 +50,20 @@ KlingCharacterEffectModelName, KlingSingleImageEffectModelName, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - tensor_to_base64_string, - download_url_to_video_output, - upload_video_to_comfyapi, - upload_audio_to_comfyapi, - download_url_to_image_tensor, - validate_string, -) -from comfy_api_nodes.util.validation_utils import ( +from comfy_api_nodes.util import ( validate_image_dimensions, validate_image_aspect_ratio, validate_video_dimensions, validate_video_duration, + tensor_to_base64_string, + validate_string, + upload_audio_to_comfyapi, + download_url_to_image_tensor, + upload_video_to_comfyapi, + download_url_to_video_output, + sync_op, + ApiEndpoint, + poll_op, ) from comfy_api.input_impl import VideoFromFile from comfy_api.input.basic_types import AudioInput @@ -214,34 +206,6 @@ } -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, R], - result_url_extractor: Optional[Callable[[R], str]] = None, - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> R: - """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[ - KlingTaskStatus.succeed.value, - ], - failed_statuses=[KlingTaskStatus.failed.value], - status_extractor=lambda response: ( - response.data.task_status.value - if response.data and response.data.task_status - else None - ), - auth_kwargs=auth_kwargs, - result_url_extractor=result_url_extractor, - estimated_duration=estimated_duration, - node_id=node_id, - poll_interval=16.0, - max_poll_attempts=256, - ).execute() - - def is_valid_camera_control_configs(configs: list[float]) -> bool: """Verifies that at least one camera control configuration is non-zero.""" return any(not math.isclose(value, 0.0) for value in configs) @@ -377,8 +341,7 @@ async def image_result_to_node_output( async def execute_text2video( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], prompt: str, negative_prompt: str, cfg_scale: float, @@ -389,14 +352,11 @@ async def execute_text2video( camera_control: Optional[KlingCameraControl] = None, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_TEXT_TO_VIDEO, - method=HttpMethod.POST, - request_model=KlingText2VideoRequest, - response_model=KlingText2VideoResponse, - ), - request=KlingText2VideoRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"), + response_model=KlingText2VideoResponse, + data=KlingText2VideoRequest( prompt=prompt if prompt else None, negative_prompt=negative_prompt if negative_prompt else None, duration=KlingVideoGenDuration(duration), @@ -406,24 +366,17 @@ async def execute_text2video( aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), camera_control=camera_control, ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_TEXT_TO_VIDEO}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingText2VideoResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_TEXT_TO_VIDEO}/{task_id}"), + response_model=KlingText2VideoResponse, estimated_duration=AVERAGE_DURATION_T2V, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -432,8 +385,7 @@ async def execute_text2video( async def execute_image2video( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], start_frame: torch.Tensor, prompt: str, negative_prompt: str, @@ -455,14 +407,11 @@ async def execute_image2video( if model_mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value: model_mode = "pro" # October 5: currently "std" mode is not supported for this model - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=KlingImage2VideoRequest, - response_model=KlingImage2VideoResponse, - ), - request=KlingImage2VideoRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), + response_model=KlingImage2VideoResponse, + data=KlingImage2VideoRequest( model_name=KlingVideoGenModelName(model_name), image=tensor_to_base64_string(start_frame), image_tail=( @@ -477,24 +426,17 @@ async def execute_image2video( duration=KlingVideoGenDuration(duration), camera_control=camera_control, ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}", - method=HttpMethod.GET, - request_model=KlingImage2VideoRequest, - response_model=KlingImage2VideoResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"), + response_model=KlingImage2VideoResponse, estimated_duration=AVERAGE_DURATION_I2V, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -503,8 +445,7 @@ async def execute_image2video( async def execute_video_effect( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], dual_character: bool, effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene, model_name: str, @@ -530,35 +471,25 @@ async def execute_video_effect( duration=duration, ) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIDEO_EFFECTS, - method=HttpMethod.POST, - request_model=KlingVideoEffectsRequest, - response_model=KlingVideoEffectsResponse, - ), - request=KlingVideoEffectsRequest( + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_VIDEO_EFFECTS, method="POST"), + response_model=KlingVideoEffectsResponse, + data=KlingVideoEffectsRequest( effect_scene=effect_scene, input=request_input_field, ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_VIDEO_EFFECTS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVideoEffectsResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIDEO_EFFECTS}/{task_id}"), + response_model=KlingVideoEffectsResponse, estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -567,8 +498,7 @@ async def execute_video_effect( async def execute_lipsync( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], video: VideoInput, audio: Optional[AudioInput] = None, voice_language: Optional[str] = None, @@ -583,24 +513,21 @@ async def execute_lipsync( validate_video_duration(video, 2, 10) # Upload video to Comfy API and get download URL - video_url = await upload_video_to_comfyapi(video, auth_kwargs=auth_kwargs) + video_url = await upload_video_to_comfyapi(cls, video) logging.info("Uploaded video to Comfy API. URL: %s", video_url) # Upload the audio file to Comfy API and get download URL if audio: - audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=auth_kwargs) + audio_url = await upload_audio_to_comfyapi(cls, audio) logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) else: audio_url = None - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_LIP_SYNC, - method=HttpMethod.POST, - request_model=KlingLipSyncRequest, - response_model=KlingLipSyncResponse, - ), - request=KlingLipSyncRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(PATH_LIP_SYNC, "POST"), + response_model=KlingLipSyncResponse, + data=KlingLipSyncRequest( input=KlingLipSyncInputObject( video_url=video_url, mode=model_mode, @@ -612,24 +539,17 @@ async def execute_lipsync( voice_id=voice_id, ), ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_LIP_SYNC}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingLipSyncResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_LIP_SYNC}/{task_id}"), + response_model=KlingLipSyncResponse, estimated_duration=AVERAGE_DURATION_LIP_SYNC, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -807,11 +727,7 @@ async def execute( ) -> IO.NodeOutput: model_mode, duration, model_name = MODE_TEXT2VIDEO[mode] return await execute_text2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt=prompt, negative_prompt=negative_prompt, cfg_scale=cfg_scale, @@ -872,11 +788,7 @@ async def execute( camera_control: Optional[KlingCameraControl] = None, ) -> IO.NodeOutput: return await execute_text2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, model_name=KlingVideoGenModelName.kling_v1, cfg_scale=cfg_scale, model_mode=KlingVideoGenMode.std, @@ -944,11 +856,7 @@ async def execute( end_frame: Optional[torch.Tensor] = None, ) -> IO.NodeOutput: return await execute_image2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, start_frame=start_frame, prompt=prompt, negative_prompt=negative_prompt, @@ -1017,11 +925,7 @@ async def execute( camera_control: KlingCameraControl, ) -> IO.NodeOutput: return await execute_image2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, model_name=KlingVideoGenModelName.kling_v1_5, start_frame=start_frame, cfg_scale=cfg_scale, @@ -1097,11 +1001,7 @@ async def execute( ) -> IO.NodeOutput: mode, duration, model_name = MODE_START_END_FRAME[mode] return await execute_image2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt=prompt, negative_prompt=negative_prompt, model_name=model_name, @@ -1162,41 +1062,27 @@ async def execute( video_id: str, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIDEO_EXTEND, - method=HttpMethod.POST, - request_model=KlingVideoExtendRequest, - response_model=KlingVideoExtendResponse, - ), - request=KlingVideoExtendRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_VIDEO_EXTEND, method="POST"), + response_model=KlingVideoExtendResponse, + data=KlingVideoExtendRequest( prompt=prompt if prompt else None, negative_prompt=negative_prompt if negative_prompt else None, cfg_scale=cfg_scale, video_id=video_id, ), - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth, - ApiEndpoint( - path=f"{PATH_VIDEO_EXTEND}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVideoExtendResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIDEO_EXTEND}/{task_id}"), + response_model=KlingVideoExtendResponse, estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND, - node_id=cls.hidden.unique_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -1259,11 +1145,7 @@ async def execute( duration: KlingVideoGenDuration, ) -> IO.NodeOutput: video, _, duration = await execute_video_effect( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, dual_character=True, effect_scene=effect_scene, model_name=model_name, @@ -1324,11 +1206,7 @@ async def execute( return IO.NodeOutput( *( await execute_video_effect( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, dual_character=False, effect_scene=effect_scene, model_name=model_name, @@ -1379,11 +1257,7 @@ async def execute( voice_language: str, ) -> IO.NodeOutput: return await execute_lipsync( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, video=video, audio=audio, voice_language=voice_language, @@ -1445,11 +1319,7 @@ async def execute( ) -> IO.NodeOutput: voice_id, voice_language = VOICES_CONFIG[voice] return await execute_lipsync( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, video=video, text=text, voice_language=voice_language, @@ -1496,40 +1366,26 @@ async def execute( cloth_image: torch.Tensor, model_name: KlingVirtualTryOnModelName, ) -> IO.NodeOutput: - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIRTUAL_TRY_ON, - method=HttpMethod.POST, - request_model=KlingVirtualTryOnRequest, - response_model=KlingVirtualTryOnResponse, - ), - request=KlingVirtualTryOnRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_VIRTUAL_TRY_ON, method="POST"), + response_model=KlingVirtualTryOnResponse, + data=KlingVirtualTryOnRequest( human_image=tensor_to_base64_string(human_image), cloth_image=tensor_to_base64_string(cloth_image), model_name=model_name, ), - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth, - ApiEndpoint( - path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVirtualTryOnResponse, - ), - result_url_extractor=get_images_urls_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}"), + response_model=KlingVirtualTryOnResponse, estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON, - node_id=cls.hidden.unique_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_image_result_response(final_response) @@ -1625,18 +1481,11 @@ async def execute( else: image = tensor_to_base64_string(image) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_GENERATIONS, - method=HttpMethod.POST, - request_model=KlingImageGenerationsRequest, - response_model=KlingImageGenerationsResponse, - ), - request=KlingImageGenerationsRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"), + response_model=KlingImageGenerationsResponse, + data=KlingImageGenerationsRequest( model_name=model_name, prompt=prompt, negative_prompt=negative_prompt, @@ -1647,24 +1496,17 @@ async def execute( n=n, aspect_ratio=aspect_ratio, ), - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth, - ApiEndpoint( - path=f"{PATH_IMAGE_GENERATIONS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingImageGenerationsResponse, - ), - result_url_extractor=get_images_urls_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_IMAGE_GENERATIONS}/{task_id}"), + response_model=KlingImageGenerationsResponse, estimated_duration=AVERAGE_DURATION_IMAGE_GEN, - node_id=cls.hidden.unique_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_image_result_response(final_response) diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 610d95a77b9d..e74441e5ef5f 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -35,9 +35,9 @@ from comfy_api_nodes.apinode_utils import ( upload_images_to_comfyapi, process_image_response, - validate_string, ) from server import PromptServer +from comfy_api_nodes.util import validate_string import aiohttp import torch diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 23be1ae65ad8..e3722e79b715 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -24,8 +24,8 @@ from comfy_api_nodes.apinode_utils import ( download_url_to_bytesio, upload_images_to_comfyapi, - validate_string, ) +from comfy_api_nodes.util import validate_string from server import PromptServer diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 7566188dd86c..426875d32a16 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -1,8 +1,7 @@ import logging -from typing import Any, Callable, Optional, TypeVar +from typing import Optional import torch from typing_extensions import override -from comfy_api_nodes.util.validation_utils import validate_image_dimensions from comfy_api_nodes.apis import ( MoonvalleyTextToVideoRequest, @@ -11,24 +10,22 @@ MoonvalleyVideoToVideoRequest, MoonvalleyPromptResponse, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( +from comfy_api_nodes.util import ( + validate_container_format_is_mp4, + validate_image_dimensions, download_url_to_video_output, - upload_images_to_comfyapi, upload_video_to_comfyapi, - validate_container_format_is_mp4, + upload_images_to_comfyapi, + sync_op, + ApiEndpoint, + poll_op, + validate_string, + trim_video, ) from comfy_api.input import VideoInput -from comfy_api.latest import ComfyExtension, InputImpl, IO -import av -import io +from comfy_api.latest import ComfyExtension, IO + API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads" API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts" @@ -51,13 +48,6 @@ MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000 -R = TypeVar("R") - - -class MoonvalleyApiError(Exception): - """Base exception for Moonvalley API errors.""" - - pass def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool: @@ -69,64 +59,7 @@ def validate_task_creation_response(response) -> None: if not is_valid_task_creation_response(response): error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}" logging.error(error_msg) - raise MoonvalleyApiError(error_msg) - - -def get_video_from_response(response): - video = response.output_url - logging.info( - "Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video - ) - return video - - -def get_video_url_from_response(response) -> Optional[str]: - """Returns the first video url from the Moonvalley video generation task result. - Will not raise an error if the response is not valid. - """ - if response: - return str(get_video_from_response(response)) - else: - return None - - -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, R], - result_url_extractor: Optional[Callable[[R], str]] = None, - node_id: Optional[str] = None, -) -> R: - """Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[ - "completed", - ], - max_poll_attempts=240, # 64 minutes with 16s interval - poll_interval=16.0, - failed_statuses=["error"], - status_extractor=lambda response: ( - response.status if response and response.status else None - ), - auth_kwargs=auth_kwargs, - result_url_extractor=result_url_extractor, - node_id=node_id, - ).execute() - - -def validate_prompts( - prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH -): - """Verifies that the prompt isn't empty and that neither prompt is too long.""" - if not prompt: - raise ValueError("Positive prompt is empty") - if len(prompt) > max_length: - raise ValueError(f"Positive prompt is too long: {len(prompt)} characters") - if negative_prompt and len(negative_prompt) > max_length: - raise ValueError( - f"Negative prompt is too long: {len(negative_prompt)} characters" - ) - return True + raise RuntimeError(error_msg) def validate_video_to_video_input(video: VideoInput) -> VideoInput: @@ -188,7 +121,7 @@ def _validate_and_trim_duration(video: VideoInput) -> VideoInput: def _validate_minimum_duration(duration: float) -> None: """Ensures video is at least 5 seconds long.""" if duration < 5: - raise MoonvalleyApiError("Input video must be at least 5 seconds long.") + raise ValueError("Input video must be at least 5 seconds long.") def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: @@ -198,123 +131,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: return video -def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: - """ - Returns a new VideoInput object trimmed from the beginning to the specified duration, - using av to avoid loading entire video into memory. - - Args: - video: Input video to trim - duration_sec: Duration in seconds to keep from the beginning - - Returns: - VideoFromFile object that owns the output buffer - """ - output_buffer = io.BytesIO() - - input_container = None - output_container = None - - try: - # Get the stream source - this avoids loading entire video into memory - # when the source is already a file path - input_source = video.get_stream_source() - - # Open containers - input_container = av.open(input_source, mode="r") - output_container = av.open(output_buffer, mode="w", format="mp4") - - # Set up output streams for re-encoding - video_stream = None - audio_stream = None - - for stream in input_container.streams: - logging.info("Found stream: type=%s, class=%s", stream.type, type(stream)) - if isinstance(stream, av.VideoStream): - # Create output video stream with same parameters - video_stream = output_container.add_stream( - "h264", rate=stream.average_rate - ) - video_stream.width = stream.width - video_stream.height = stream.height - video_stream.pix_fmt = "yuv420p" - logging.info( - "Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate - ) - elif isinstance(stream, av.AudioStream): - # Create output audio stream with same parameters - audio_stream = output_container.add_stream( - "aac", rate=stream.sample_rate - ) - audio_stream.sample_rate = stream.sample_rate - audio_stream.layout = stream.layout - logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels) - - # Calculate target frame count that's divisible by 16 - fps = input_container.streams.video[0].average_rate - estimated_frames = int(duration_sec * fps) - target_frames = ( - estimated_frames // 16 - ) * 16 # Round down to nearest multiple of 16 - - if target_frames == 0: - raise ValueError("Video too short: need at least 16 frames for Moonvalley") - - frame_count = 0 - audio_frame_count = 0 - - # Decode and re-encode video frames - if video_stream: - for frame in input_container.decode(video=0): - if frame_count >= target_frames: - break - - # Re-encode frame - for packet in video_stream.encode(frame): - output_container.mux(packet) - frame_count += 1 - - # Flush encoder - for packet in video_stream.encode(): - output_container.mux(packet) - - logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames) - - # Decode and re-encode audio frames - if audio_stream: - input_container.seek(0) # Reset to beginning for audio - for frame in input_container.decode(audio=0): - if frame.time >= duration_sec: - break - - # Re-encode frame - for packet in audio_stream.encode(frame): - output_container.mux(packet) - audio_frame_count += 1 - - # Flush encoder - for packet in audio_stream.encode(): - output_container.mux(packet) - - logging.info("Encoded %s audio frames", audio_frame_count) - - # Close containers - output_container.close() - input_container.close() - - # Return as VideoFromFile using the buffer - output_buffer.seek(0) - return InputImpl.VideoFromFile(output_buffer) - - except Exception as e: - # Clean up on error - if input_container is not None: - input_container.close() - if output_container is not None: - output_container.close() - raise RuntimeError(f"Failed to trim video: {str(e)}") from e - - def parse_width_height_from_res(resolution: str): # Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict res_map = { @@ -338,19 +154,12 @@ def parse_control_parameter(value): return control_map.get(value, control_map["Motion Transfer"]) -async def get_response( - task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None -) -> MoonvalleyPromptResponse: - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{API_PROMPTS_ENDPOINT}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MoonvalleyPromptResponse, - ), - result_url_extractor=get_video_url_from_response, - node_id=node_id, +async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse: + return await poll_op( + cls, + ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"), + response_model=MoonvalleyPromptResponse, + status_extractor=lambda r: (r.status if r and r.status else None), ) @@ -444,14 +253,10 @@ async def execute( steps: int, ) -> IO.NodeOutput: validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH) - validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = parse_width_height_from_res(resolution) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - inference_params = MoonvalleyTextToVideoInferenceParams( negative_prompt=negative_prompt, steps=steps, @@ -464,33 +269,17 @@ async def execute( # Get MIME type from tensor - assuming PNG format for image tensors mime_type = "image/png" - - image_url = ( - await upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=auth, mime_type=mime_type - ) - )[0] - - request = MoonvalleyTextToVideoRequest( - image_url=image_url, prompt_text=prompt, inference_params=inference_params - ) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=API_IMG2VIDEO_ENDPOINT, - method=HttpMethod.POST, - request_model=MoonvalleyTextToVideoRequest, - response_model=MoonvalleyPromptResponse, + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0] + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyTextToVideoRequest( + image_url=image_url, prompt_text=prompt, inference_params=inference_params ), - request=request, - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) - task_id = task_creation_response.id - - final_response = await get_response( - task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id - ) + final_response = await get_response(cls, task_creation_response.id) video = await download_url_to_video_output(final_response.output_url) return IO.NodeOutput(video) @@ -582,15 +371,10 @@ async def execute( steps=33, prompt_adherence=4.5, ) -> IO.NodeOutput: - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - validated_video = validate_video_to_video_input(video) - video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth) - - validate_prompts(prompt, negative_prompt) + video_url = await upload_video_to_comfyapi(cls, validated_video) + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) # Only include motion_intensity for Motion Transfer control_params = {} @@ -605,35 +389,20 @@ async def execute( guidance_scale=prompt_adherence, ) - control = parse_control_parameter(control_type) - - request = MoonvalleyVideoToVideoRequest( - control_type=control, - video_url=video_url, - prompt_text=prompt, - inference_params=inference_params, - ) - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=API_VIDEO2VIDEO_ENDPOINT, - method=HttpMethod.POST, - request_model=MoonvalleyVideoToVideoRequest, - response_model=MoonvalleyPromptResponse, + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyVideoToVideoRequest( + control_type=parse_control_parameter(control_type), + video_url=video_url, + prompt_text=prompt, + inference_params=inference_params, ), - request=request, - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) - task_id = task_creation_response.id - - final_response = await get_response( - task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id - ) - - video = await download_url_to_video_output(final_response.output_url) - return IO.NodeOutput(video) + final_response = await get_response(cls, task_creation_response.id) + return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) class MoonvalleyTxt2VideoNode(IO.ComfyNode): @@ -720,14 +489,10 @@ async def execute( seed: int, steps: int, ) -> IO.NodeOutput: - validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = parse_width_height_from_res(resolution) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - inference_params = MoonvalleyTextToVideoInferenceParams( negative_prompt=negative_prompt, steps=steps, @@ -737,30 +502,16 @@ async def execute( width=width_height["width"], height=width_height["height"], ) - request = MoonvalleyTextToVideoRequest( - prompt_text=prompt, inference_params=inference_params - ) - init_op = SynchronousOperation( - endpoint=ApiEndpoint( - path=API_TXT2VIDEO_ENDPOINT, - method=HttpMethod.POST, - request_model=MoonvalleyTextToVideoRequest, - response_model=MoonvalleyPromptResponse, - ), - request=request, - auth_kwargs=auth, + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params), ) - task_creation_response = await init_op.execute() validate_task_creation_response(task_creation_response) - task_id = task_creation_response.id - - final_response = await get_response( - task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id - ) - - video = await download_url_to_video_output(final_response.output_url) - return IO.NodeOutput(video) + final_response = await get_response(cls, task_creation_response.id) + return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) class MoonvalleyExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index e3b81de7599e..c467e840cf65 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -43,13 +43,11 @@ ) from comfy_api_nodes.apinode_utils import ( - downscale_image_tensor, validate_and_cast_response, - validate_string, - tensor_to_base64_string, text_filepath_to_data_uri, ) from comfy_api_nodes.mapper_utils import model_field_to_node_input +from comfy_api_nodes.util import downscale_image_tensor, validate_string, tensor_to_base64_string RESPONSES_ENDPOINT = "/proxy/openai/v1/responses" diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py index 27cb0067b008..5bb406a3bb77 100644 --- a/comfy_api_nodes/nodes_pika.py +++ b/comfy_api_nodes/nodes_pika.py @@ -14,11 +14,6 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput -from comfy_api_nodes.apinode_utils import ( - download_url_to_video_output, - tensor_to_bytesio, - validate_string, -) from comfy_api_nodes.apis import pika_defs from comfy_api_nodes.apis.client import ( ApiEndpoint, @@ -27,6 +22,7 @@ PollingOperation, SynchronousOperation, ) +from comfy_api_nodes.util import validate_string, download_url_to_video_output, tensor_to_bytesio R = TypeVar("R") diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 438a7f80b1e9..b2b841be88ff 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -24,10 +24,7 @@ PollingOperation, EmptyRequest, ) -from comfy_api_nodes.apinode_utils import ( - tensor_to_bytesio, - validate_string, -) +from comfy_api_nodes.util import validate_string, tensor_to_bytesio from comfy_api.input_impl import VideoFromFile from comfy_api.latest import ComfyExtension, IO @@ -50,7 +47,6 @@ def get_video_url_from_response( async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): # first, upload image to Pixverse and get image id to use in actual generation call - files = {"image": tensor_to_bytesio(image)} operation = SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/pixverse/image/upload", @@ -59,16 +55,14 @@ async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): response_model=PixverseImageUploadResponse, ), request=EmptyRequest(), - files=files, + files={"image": tensor_to_bytesio(image)}, content_type="multipart/form-data", auth_kwargs=auth_kwargs, ) response_upload: PixverseImageUploadResponse = await operation.execute() if response_upload.Resp is None: - raise Exception( - f"PixVerse image upload request failed: '{response_upload.ErrMsg}'" - ) + raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'") return response_upload.Resp.img_id @@ -95,7 +89,6 @@ def execute(cls, template: str) -> IO.NodeOutput: template_id = pixverse_templates.get(template, None) if template_id is None: raise Exception(f"Template '{template}' is not recognized.") - # just return the integer return IO.NodeOutput(template_id) diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 8beed5675c17..8ee7e55c4e71 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -24,12 +24,10 @@ EmptyRequest, ) from comfy_api_nodes.apinode_utils import ( - bytesio_to_image_tensor, download_url_to_bytesio, - tensor_to_bytesio, resize_mask_to_image, - validate_string, ) +from comfy_api_nodes.util import validate_string, tensor_to_bytesio, bytesio_to_image_tensor from server import PromptServer import torch diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index eb03a897dece..aac69167ec9c 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -11,7 +11,7 @@ """ -from typing import Union, Optional, Any +from typing import Union, Optional from typing_extensions import override from enum import Enum @@ -33,23 +33,9 @@ ReferenceImage, RunwayTextToImageAspectRatioEnum, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - upload_images_to_comfyapi, - download_url_to_video_output, - image_tensor_pair_to_batch, - validate_string, - download_url_to_image_tensor, -) +from comfy_api_nodes.util import image_tensor_pair_to_batch, validate_string, validate_image_dimensions, validate_image_aspect_ratio, upload_images_to_comfyapi, download_url_to_video_output, download_url_to_image_tensor, ApiEndpoint, sync_op, poll_op from comfy_api.input_impl import VideoFromFile from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" @@ -91,31 +77,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N return None -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, TaskStatusResponse], - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> TaskStatusResponse: - """Polls the Runway API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[ - TaskStatus.SUCCEEDED.value, - ], - failed_statuses=[ - TaskStatus.FAILED.value, - TaskStatus.CANCELLED.value, - ], - status_extractor=lambda response: response.status.value, - auth_kwargs=auth_kwargs, - result_url_extractor=get_video_url_from_task_status, - estimated_duration=estimated_duration, - node_id=node_id, - progress_extractor=extract_progress_from_task_status, - ).execute() - - def extract_progress_from_task_status( response: TaskStatusResponse, ) -> Union[float, None]: @@ -132,42 +93,40 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N async def get_response( - task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None + cls: type[IO.ComfyNode], + task_id: str, estimated_duration: Optional[int] = None ) -> TaskStatusResponse: """Poll the task status until it is finished then get the response.""" - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_GET_TASK_STATUS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), + return await poll_op( + cls, + ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"), + completed_statuses=[ + TaskStatus.SUCCEEDED.value, + ], + failed_statuses=[ + TaskStatus.FAILED.value, + TaskStatus.CANCELLED.value, + ], + response_model=TaskStatusResponse, + status_extractor=lambda r: r.status.value, estimated_duration=estimated_duration, - node_id=node_id, + progress_extractor=extract_progress_from_task_status, ) async def generate_video( + cls: type[IO.ComfyNode], request: RunwayImageToVideoRequest, - auth_kwargs: dict[str, str], - node_id: Optional[str] = None, estimated_duration: Optional[int] = None, ) -> VideoFromFile: - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=RunwayImageToVideoRequest, - response_model=RunwayImageToVideoResponse, - ), - request=request, - auth_kwargs=auth_kwargs, + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), + response_model=RunwayImageToVideoResponse, + data=request, ) - initial_response = await initial_operation.execute() - - final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration) + final_response = await get_response(cls, initial_response.id, estimated_duration) if not final_response.output: raise RunwayApiError("Runway task succeeded but no video data found in response.") @@ -241,20 +200,16 @@ async def execute( validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - download_urls = await upload_images_to_comfyapi( + cls, start_frame, max_images=1, mime_type="image/png", - auth_kwargs=auth_kwargs, ) return IO.NodeOutput( await generate_video( + cls, RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -269,8 +224,6 @@ async def execute( ] ), ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, ) ) @@ -341,20 +294,16 @@ async def execute( validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - download_urls = await upload_images_to_comfyapi( + cls, start_frame, max_images=1, mime_type="image/png", - auth_kwargs=auth_kwargs, ) return IO.NodeOutput( await generate_video( + cls, RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -369,8 +318,6 @@ async def execute( ] ), ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_FLF_SECONDS, ) ) @@ -452,23 +399,19 @@ async def execute( validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) download_urls = await upload_images_to_comfyapi( + cls, stacked_input_images, max_images=2, mime_type="image/png", - auth_kwargs=auth_kwargs, ) if len(download_urls) != 2: raise RunwayApiError("Failed to upload one or more images to comfy api.") return IO.NodeOutput( await generate_video( + cls, RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -486,8 +429,6 @@ async def execute( ] ), ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_FLF_SECONDS, ) ) @@ -540,49 +481,34 @@ async def execute( ) -> IO.NodeOutput: validate_string(prompt, min_length=1) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - # Prepare reference images if provided reference_images = None if reference_image is not None: validate_image_dimensions(reference_image, max_width=7999, max_height=7999) validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0) download_urls = await upload_images_to_comfyapi( + cls, reference_image, max_images=1, mime_type="image/png", - auth_kwargs=auth_kwargs, ) reference_images = [ReferenceImage(uri=str(download_urls[0]))] - request = RunwayTextToImageRequest( - promptText=prompt, - model=Model4.gen4_image, - ratio=ratio, - referenceImages=reference_images, - ) - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_TEXT_TO_IMAGE, - method=HttpMethod.POST, - request_model=RunwayTextToImageRequest, - response_model=RunwayTextToImageResponse, + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_TEXT_TO_IMAGE, method="POST"), + response_model=RunwayTextToImageResponse, + data=RunwayTextToImageRequest( + promptText=prompt, + model=Model4.gen4_image, + ratio=ratio, + referenceImages=reference_images, ), - request=request, - auth_kwargs=auth_kwargs, ) - initial_response = await initial_operation.execute() - - # Poll for completion final_response = await get_response( + cls, initial_response.id, - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_T2I_SECONDS, ) if not final_response.output: diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index efc95486977e..92b225d4043d 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -1,23 +1,20 @@ from typing import Optional -from typing_extensions import override import torch from pydantic import BaseModel, Field -from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.util.validation_utils import get_number_of_images +from typing_extensions import override -from comfy_api_nodes.apinode_utils import ( +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.util import ( + ApiEndpoint, download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, tensor_to_bytesio, ) + class Sora2GenerationRequest(BaseModel): prompt: str = Field(...) model: str = Field(...) @@ -80,7 +77,7 @@ def define_schema(cls): control_after_generate=True, optional=True, tooltip="Seed to determine if node should re-run; " - "actual results are nondeterministic regardless of seed.", + "actual results are nondeterministic regardless of seed.", ), ], outputs=[ @@ -111,55 +108,34 @@ async def execute( if get_number_of_images(image) != 1: raise ValueError("Currently only one input image is supported.") files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")} - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - payload = Sora2GenerationRequest( - model=model, - prompt=prompt, - seconds=str(duration), - size=size, - ) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/openai/v1/videos", - method=HttpMethod.POST, - request_model=Sora2GenerationRequest, - response_model=Sora2GenerationResponse + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/openai/v1/videos", method="POST"), + data=Sora2GenerationRequest( + model=model, + prompt=prompt, + seconds=str(duration), + size=size, ), - request=payload, files=files_input, - auth_kwargs=auth, + response_model=Sora2GenerationResponse, content_type="multipart/form-data", ) - initial_response = await initial_operation.execute() if initial_response.error: - raise Exception(initial_response.error.message) + raise Exception(initial_response.error["message"]) model_time_multiplier = 1 if model == "sora-2" else 2 - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/openai/v1/videos/{initial_response.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=Sora2GenerationResponse - ), - completed_statuses=["completed"], - failed_statuses=["failed"], + await poll_op( + cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/openai/v1/videos/{initial_response.id}"), + response_model=Sora2GenerationResponse, status_extractor=lambda x: x.status, - auth_kwargs=auth, poll_interval=8.0, max_poll_attempts=160, - node_id=cls.hidden.unique_id, - estimated_duration=45 * (duration / 4) * model_time_multiplier, + estimated_duration=int(45 * (duration / 4) * model_time_multiplier), ) - await poll_operation.execute() return IO.NodeOutput( - await download_url_to_video_output( - f"/proxy/openai/v1/videos/{initial_response.id}/content", - auth_kwargs=auth, - ) + await download_url_to_video_output(f"/proxy/openai/v1/videos/{initial_response.id}/content", cls=cls), ) diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 8af03cfd1247..783666ddf5fa 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -27,14 +27,14 @@ PollingOperation, EmptyRequest, ) -from comfy_api_nodes.apinode_utils import ( +from comfy_api_nodes.util import ( + validate_audio_duration, + validate_string, + audio_input_to_mp3, bytesio_to_image_tensor, tensor_to_bytesio, - validate_string, audio_bytes_to_audio_input, - audio_input_to_mp3, ) -from comfy_api_nodes.util.validation_utils import validate_audio_duration import torch import base64 diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 4ab5c518614d..de2408e38fdf 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -21,10 +21,7 @@ PollingOperation, ) -from comfy_api_nodes.apinode_utils import ( - downscale_image_tensor, - tensor_to_base64_string, -) +from comfy_api_nodes.util import downscale_image_tensor, tensor_to_base64_string AVERAGE_DURATION_VIDEO_GEN = 32 MODELS_MAP = { diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 639be4b2be66..9c4d30bc3b24 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -1,26 +1,23 @@ import logging from enum import Enum -from typing import Any, Callable, Optional, Literal, TypeVar +from typing import Optional, Literal, TypeVar from typing_extensions import override import torch from pydantic import BaseModel, Field from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.util.validation_utils import ( +from comfy_api_nodes.util import ( validate_aspect_ratio_closeness, validate_image_dimensions, validate_image_aspect_ratio_range, get_number_of_images, -) -from comfy_api_nodes.apis.client import ( + download_url_to_video_output, + upload_images_to_comfyapi, ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, + sync_op, + poll_op, ) -from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" @@ -63,17 +60,9 @@ class TaskCreationRequest(BaseModel): images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL") -class TaskStatus(str, Enum): - created = "created" - queueing = "queueing" - processing = "processing" - success = "success" - failed = "failed" - - class TaskCreationResponse(BaseModel): task_id: str = Field(...) - state: TaskStatus = Field(...) + state: str = Field(...) created_at: str = Field(...) code: Optional[int] = Field(None, description="Error code") @@ -85,32 +74,11 @@ class TaskResult(BaseModel): class TaskStatusResponse(BaseModel): - state: TaskStatus = Field(...) + state: str = Field(...) err_code: Optional[str] = Field(None) creations: list[TaskResult] = Field(..., description="Generated results") -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, R], - result_url_extractor: Optional[Callable[[R], str]] = None, - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> R: - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[TaskStatus.success.value], - failed_statuses=[TaskStatus.failed.value], - status_extractor=lambda response: response.state.value, - auth_kwargs=auth_kwargs, - result_url_extractor=result_url_extractor, - estimated_duration=estimated_duration, - node_id=node_id, - poll_interval=16.0, - max_poll_attempts=256, - ).execute() - - def get_video_url_from_response(response) -> Optional[str]: if response.creations: return response.creations[0].url @@ -127,37 +95,27 @@ def get_video_from_response(response) -> TaskResult: async def execute_task( + cls: type[IO.ComfyNode], vidu_endpoint: str, - auth_kwargs: Optional[dict[str, str]], payload: TaskCreationRequest, estimated_duration: int, - node_id: str, ) -> R: - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=vidu_endpoint, - method=HttpMethod.POST, - request_model=TaskCreationRequest, - response_model=TaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - if response.state == TaskStatus.failed: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=vidu_endpoint,method="POST"), + response_model=TaskCreationResponse, + data=payload, + ) + if response.state == "failed": error_msg = f"Vidu request failed. Code: {response.code}" logging.error(error_msg) raise RuntimeError(error_msg) - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=VIDU_GET_GENERATION_STATUS % response.task_id, - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - result_url_extractor=get_video_url_from_response, + return await poll_op( + cls, + ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id), + response_model=TaskStatusResponse, + status_extractor=lambda r: r.state.value, estimated_duration=estimated_duration, - node_id=node_id, ) @@ -258,11 +216,7 @@ async def execute( resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -362,17 +316,13 @@ async def execute( resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } payload.images = await upload_images_to_comfyapi( + cls, image, max_images=1, mime_type="image/png", - auth_kwargs=auth, ) - results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -484,17 +434,13 @@ async def execute( resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } payload.images = await upload_images_to_comfyapi( + cls, images, max_images=7, mime_type="image/png", - auth_kwargs=auth, ) - results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -596,15 +542,11 @@ async def execute( resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } payload.images = [ - (await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0] + (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] for frame in (first_frame, end_frame) ] - results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index b089bd907b25..61d50746b2de 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -14,14 +14,8 @@ R, T, ) -from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration +from comfy_api_nodes.util import get_number_of_images, validate_audio_duration, tensor_to_base64_string, audio_to_base64_string, download_url_to_video_output, download_url_to_image_tensor -from comfy_api_nodes.apinode_utils import ( - download_url_to_image_tensor, - download_url_to_video_output, - tensor_to_base64_string, - audio_to_base64_string, -) class Text2ImageInputField(BaseModel): prompt: str = Field(...) diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index fe3cda258d6d..c2ec391aadd4 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -1,23 +1,87 @@ -from .api_client import ApiEndpoint, sync_op_pydantic, poll_op_pydantic, sync_op, poll_op +from ._helpers import get_fs_object_size +from .client import ( + ApiEndpoint, + poll_op, + poll_op_raw, + sync_op, + sync_op_raw, +) +from .conversions import ( + audio_bytes_to_audio_input, + audio_input_to_mp3, + audio_to_base64_string, + bytesio_to_image_tensor, + downscale_image_tensor, + image_tensor_pair_to_batch, + pil_to_bytesio, + tensor_to_base64_string, + tensor_to_bytesio, + tensor_to_pil, + trim_video, +) from .download_helpers import ( download_url_to_bytesio, download_url_to_image_tensor, - bytesio_to_image_tensor, + download_url_to_video_output, ) from .upload_helpers import ( + upload_audio_to_comfyapi, upload_file_to_comfyapi, upload_images_to_comfyapi, + upload_video_to_comfyapi, +) +from .validation_utils import ( + get_number_of_images, + validate_aspect_ratio_closeness, + validate_audio_duration, + validate_container_format_is_mp4, + validate_image_aspect_ratio, + validate_image_aspect_ratio_range, + validate_image_dimensions, + validate_string, + validate_video_dimensions, + validate_video_duration, ) __all__ = [ + # API client "ApiEndpoint", "poll_op", + "poll_op_raw", "sync_op", - "poll_op_pydantic", - "sync_op_pydantic", + "sync_op_raw", + # Upload helpers + "upload_audio_to_comfyapi", "upload_file_to_comfyapi", "upload_images_to_comfyapi", + "upload_video_to_comfyapi", + # Download helpers "download_url_to_bytesio", "download_url_to_image_tensor", + "download_url_to_video_output", + # Conversions + "audio_bytes_to_audio_input", + "audio_input_to_mp3", + "audio_to_base64_string", "bytesio_to_image_tensor", + "downscale_image_tensor", + "image_tensor_pair_to_batch", + "pil_to_bytesio", + "tensor_to_base64_string", + "tensor_to_bytesio", + "tensor_to_pil", + "trim_video", + # Validation utilities + "get_number_of_images", + "validate_aspect_ratio_closeness", + "validate_audio_duration", + "validate_container_format_is_mp4", + "validate_image_aspect_ratio", + "validate_image_aspect_ratio_range", + "validate_image_dimensions", + "validate_string", + "validate_video_dimensions", + "validate_video_duration", + # Misc functions + "get_fs_object_size", ] diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 1bf951adc48e..bf6c32d49771 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -1,25 +1,27 @@ import asyncio import contextlib +import os import time -from typing import Optional, Callable +from io import BytesIO +from typing import Callable, Optional, Union -from comfy_api.latest import IO from comfy.cli_args import args from comfy.model_management import processing_interrupted +from comfy_api.latest import IO from .common_exceptions import ProcessingInterrupted -def _is_processing_interrupted() -> bool: +def is_processing_interrupted() -> bool: """Return True if user/runtime requested interruption.""" return processing_interrupted() -def _get_node_id(node_cls: type[IO.ComfyNode]) -> str: +def get_node_id(node_cls: type[IO.ComfyNode]) -> str: return node_cls.hidden.unique_id -def _get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]: +def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]: if node_cls.hidden.auth_token_comfy_org: return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"} if node_cls.hidden.api_key_comfy_org: @@ -27,11 +29,11 @@ def _get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]: return {} -def _default_base_url() -> str: +def default_base_url() -> str: return getattr(args, "comfy_api_base", "https://api.comfy.org") -async def _sleep_with_interrupt( +async def sleep_with_interrupt( seconds: float, node_cls: type[IO.ComfyNode], label: Optional[str] = None, @@ -47,7 +49,7 @@ async def _sleep_with_interrupt( """ end = time.monotonic() + seconds while True: - if _is_processing_interrupted(): + if is_processing_interrupted(): raise ProcessingInterrupted("Task cancelled") now = time.monotonic() if start_ts is not None and label and display_callback: @@ -56,3 +58,14 @@ async def _sleep_with_interrupt( if now >= end: break await asyncio.sleep(min(1.0, end - now)) + + +def mimetype_to_extension(mime_type: str) -> str: + """Converts a MIME type to a file extension.""" + return mime_type.split("/")[-1].lower() + + +def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int: + if isinstance(path_or_object, str): + return os.path.getsize(path_or_object) + return len(path_or_object.getvalue()) diff --git a/comfy_api_nodes/util/api_client.py b/comfy_api_nodes/util/client.py similarity index 93% rename from comfy_api_nodes/util/api_client.py rename to comfy_api_nodes/util/client.py index 95614d6b6987..184e23824451 100644 --- a/comfy_api_nodes/util/api_client.py +++ b/comfy_api_nodes/util/client.py @@ -8,26 +8,26 @@ from dataclasses import dataclass from enum import Enum from io import BytesIO -from typing import Any, Callable, Optional, Union, Type, TypeVar, Literal +from typing import Any, Callable, Literal, Optional, Type, TypeVar, Union +from urllib.parse import urljoin, urlparse import aiohttp from aiohttp.client_exceptions import ClientError, ContentTypeError -from comfy_api.latest import IO -from comfy import utils from pydantic import BaseModel -from server import PromptServer -from urllib.parse import urljoin, urlparse +from comfy import utils +from comfy_api.latest import IO from comfy_api_nodes.apis import request_logger -from .common_exceptions import ProcessingInterrupted, LocalNetworkError, ApiServerError +from server import PromptServer + from ._helpers import ( - _is_processing_interrupted, - _get_node_id, - _get_auth_header, - _default_base_url, - _sleep_with_interrupt, + default_base_url, + get_auth_header, + get_node_id, + is_processing_interrupted, + sleep_with_interrupt, ) - +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted M = TypeVar("M", bound=BaseModel) @@ -78,9 +78,12 @@ class _PollUIState: _RETRY_STATUS = {408, 429, 500, 502, 503, 504} +COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"] +FAILED_STATUSES = ["cancelled", "failed", "error"] +QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] -async def sync_op_pydantic( +async def sync_op( cls: type[IO.ComfyNode], endpoint: ApiEndpoint, *, @@ -99,7 +102,7 @@ async def sync_op_pydantic( progress_origin_ts: Optional[float] = None, monitor_progress: bool = True, ) -> M: - raw = await sync_op( + raw = await sync_op_raw( cls, endpoint, data=data, @@ -122,17 +125,17 @@ async def sync_op_pydantic( return _validate_or_raise(response_model, raw) -async def poll_op_pydantic( +async def poll_op( cls: type[IO.ComfyNode], - *, poll_endpoint: ApiEndpoint, + *, response_model: Type[M], status_extractor: Callable[[M], Optional[str]], progress_extractor: Optional[Callable[[M], Optional[int]]] = None, price_extractor: Optional[Callable[[M], Optional[float]]] = None, - completed_statuses: list[str], - failed_statuses: list[str], - queued_states: Optional[list[str]] = None, + completed_statuses: Optional[list[Union[str, int]]] = None, + failed_statuses: Optional[list[Union[str, int]]] = None, + queued_statuses: Optional[list[Union[str, int]]] = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, timeout_per_poll: float = 120.0, @@ -143,7 +146,7 @@ async def poll_op_pydantic( cancel_endpoint: Optional[ApiEndpoint] = None, cancel_timeout: float = 10.0, ) -> M: - raw = await poll_op( + raw = await poll_op_raw( cls, poll_endpoint=poll_endpoint, status_extractor=_wrap_model_extractor(response_model, status_extractor), @@ -151,7 +154,7 @@ async def poll_op_pydantic( price_extractor=_wrap_model_extractor(response_model, price_extractor), completed_statuses=completed_statuses, failed_statuses=failed_statuses, - queued_states=queued_states, + queued_statuses=queued_statuses, poll_interval=poll_interval, max_poll_attempts=max_poll_attempts, timeout_per_poll=timeout_per_poll, @@ -167,7 +170,7 @@ async def poll_op_pydantic( return _validate_or_raise(response_model, raw) -async def sync_op( +async def sync_op_raw( cls: type[IO.ComfyNode], endpoint: ApiEndpoint, *, @@ -216,16 +219,16 @@ async def sync_op( return await _request_base(cfg, expect_binary=as_binary) -async def poll_op( +async def poll_op_raw( cls: type[IO.ComfyNode], - *, poll_endpoint: ApiEndpoint, + *, status_extractor: Callable[[dict[str, Any]], Optional[str]], progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None, price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, - completed_statuses: list[str], - failed_statuses: list[str], - queued_states: Optional[list[str]] = None, + completed_statuses: Optional[list[Union[str, int]]] = None, + failed_statuses: Optional[list[Union[str, int]]] = None, + queued_statuses: Optional[list[Union[str, int]]] = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, timeout_per_poll: float = 120.0, @@ -239,9 +242,14 @@ async def poll_op( """ Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing, checks interruption every second, and calls Cancel endpoint (if provided) on interruption. + + Uses default complete, failed and queued states assumption. + Returns the final JSON response from the poll endpoint. """ - queued_states = queued_states or [] + completed_states = COMPLETED_STATUSES if completed_statuses is None else completed_statuses + failed_states = FAILED_STATUSES if failed_statuses is None else failed_statuses + queued_states = QUEUED_STATUSES if queued_statuses is None else queued_statuses started = time.monotonic() consumed_attempts = 0 # counts only non-queued polls @@ -255,7 +263,7 @@ async def _ticker(): """Emit a UI update every second while polling is in progress.""" try: while not stop_ticker.is_set(): - if _is_processing_interrupted(): + if is_processing_interrupted(): break now = time.monotonic() proc_elapsed = state.base_processing_elapsed + ( @@ -278,7 +286,7 @@ async def _ticker(): try: while consumed_attempts < max_poll_attempts: try: - resp_json = await sync_op( + resp_json = await sync_op_raw( cls, poll_endpoint, timeout=timeout_per_poll, @@ -296,7 +304,7 @@ async def _ticker(): except ProcessingInterrupted: if cancel_endpoint: with contextlib.suppress(Exception): - await sync_op( + await sync_op_raw( cls, cancel_endpoint, timeout=cancel_timeout, @@ -331,7 +339,7 @@ async def _ticker(): if is_queued: if state.active_since is not None: # If we just moved from active -> queued, close the active interval - state.base_processing_elapsed += (now_ts - state.active_since) + state.base_processing_elapsed += now_ts - state.active_since state.active_since = None else: if state.active_since is None: # If we just moved from queued -> active, open a new active interval @@ -339,9 +347,9 @@ async def _ticker(): state.is_queued = is_queued state.status_label = status or ("Queued" if is_queued else "Processing") - if status in completed_statuses: + if status in completed_states: if state.active_since is not None: - state.base_processing_elapsed += (now_ts - state.active_since) + state.base_processing_elapsed += now_ts - state.active_since state.active_since = None stop_ticker.set() with contextlib.suppress(Exception): @@ -361,17 +369,17 @@ async def _ticker(): ) return resp_json - if status in failed_statuses: + if status in failed_states: msg = f"Task failed: {json.dumps(resp_json)}" logging.error(msg) raise Exception(msg) try: - await _sleep_with_interrupt(poll_interval, cls, None, None, None) + await sleep_with_interrupt(poll_interval, cls, None, None, None) except ProcessingInterrupted: if cancel_endpoint: with contextlib.suppress(Exception): - await sync_op( + await sync_op_raw( cls, cancel_endpoint, timeout=cancel_timeout, @@ -417,7 +425,7 @@ def _display_text( if text is not None: display_lines.append(text) if display_lines: - PromptServer.instance.send_progress_text("\n".join(display_lines), _get_node_id(node_cls)) + PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls)) def _display_time_progress( @@ -456,7 +464,7 @@ async def _diagnose_connectivity() -> dict[str, bool]: results["is_local_issue"] = True return results - parsed = urlparse(_default_base_url()) + parsed = urlparse(default_base_url()) health_url = f"{parsed.scheme}://{parsed.netloc}/health" with contextlib.suppress(ClientError, asyncio.TimeoutError): async with session.get(health_url) as resp: @@ -480,7 +488,7 @@ def _join_url(base_url: str, path: str) -> str: def _merge_headers(node_cls: type[IO.ComfyNode], endpoint_headers: dict[str, str]) -> dict[str, str]: headers = {"Accept": "*/*"} - headers.update(_get_auth_header(node_cls)) + headers.update(get_auth_header(node_cls)) if endpoint_headers: headers.update(endpoint_headers) return headers @@ -558,7 +566,7 @@ def _snapshot_request_body_for_logging( async def _request_base(cfg: _RequestConfig, expect_binary: bool): """Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors.""" - url = _join_url(_default_base_url(), cfg.endpoint.path) + url = _join_url(default_base_url(), cfg.endpoint.path) method = cfg.endpoint.method params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None) @@ -566,7 +574,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): """Every second: update elapsed time and signal interruption.""" try: while not stop_evt.is_set(): - if _is_processing_interrupted(): + if is_processing_interrupted(): return if cfg.monitor_progress: _display_time_progress( @@ -700,7 +708,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): except Exception as _log_e: logging.debug("[DEBUG] response logging failed: %s", _log_e) - await _sleep_with_interrupt( + await sleep_with_interrupt( delay, cfg.node_cls, cfg.wait_label if cfg.monitor_progress else None, @@ -735,7 +743,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): now = time.monotonic() if now - last_tick >= 1.0: last_tick = now - if _is_processing_interrupted(): + if is_processing_interrupted(): raise ProcessingInterrupted("Task cancelled") if cfg.monitor_progress: _display_time_progress( @@ -790,7 +798,12 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): if attempt <= cfg.max_retries: logging.warning( "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s", - method, url, delay, attempt, cfg.max_retries, str(e) + method, + url, + delay, + attempt, + cfg.max_retries, + str(e), ) try: request_logger.log_request_response( @@ -804,7 +817,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): ) except Exception as _log_e: logging.debug("[DEBUG] request error logging failed: %s", _log_e) - await _sleep_with_interrupt( + await sleep_with_interrupt( delay, cfg.node_cls, cfg.wait_label if cfg.monitor_progress else None, @@ -845,7 +858,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): except Exception as _log_e: logging.debug("[DEBUG] final error logging failed: %s", _log_e) raise ApiServerError( - f"The API server at {_default_base_url()} is currently unreachable. " + f"The API server at {default_base_url()} is currently unreachable. " f"The service may be experiencing issues." ) from e finally: diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 207fe0ef25b6..3cf81459e173 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -1,8 +1,21 @@ +import base64 +import logging +import math +import uuid from io import BytesIO +from typing import Optional +import av import numpy as np -from PIL import Image import torch +from PIL import Image + +from comfy.utils import common_upscale +from comfy_api.input import VideoInput +from comfy_api.input.basic_types import AudioInput +from comfy_api.latest import InputImpl + +from ._helpers import mimetype_to_extension def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor: @@ -23,3 +36,374 @@ def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch image = image.convert(mode) image_array = np.array(image).astype(np.float32) / 255.0 return torch.from_numpy(image_array).unsqueeze(0) + + +def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Converts a pair of image tensors to a batch tensor. + If the images are not the same size, the smaller image is resized to + match the larger image. + """ + if image1.shape[1:] != image2.shape[1:]: + image2 = common_upscale( + image2.movedim(-1, 1), + image1.shape[2], + image1.shape[1], + "bilinear", + "center", + ).movedim(1, -1) + return torch.cat((image1, image2), dim=0) + + +def tensor_to_bytesio( + image: torch.Tensor, + name: Optional[str] = None, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> BytesIO: + """Converts a torch.Tensor image to a named BytesIO object. + + Args: + image: Input torch.Tensor image. + name: Optional filename for the BytesIO object. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). + + Returns: + Named BytesIO object containing the image data, with pointer set to the start of buffer. + """ + if not mime_type: + mime_type = "image/png" + + pil_image = tensor_to_pil(image, total_pixels=total_pixels) + img_binary = pil_to_bytesio(pil_image, mime_type=mime_type) + img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" + return img_binary + + +def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: + """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling.""" + if len(image.shape) > 3: + image = image[0] + # TODO: remove alpha if not allowed and present + input_tensor = image.cpu() + input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze() + image_np = (input_tensor.numpy() * 255).astype(np.uint8) + img = Image.fromarray(image_np) + return img + + +def tensor_to_base64_string( + image_tensor: torch.Tensor, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> str: + """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. + + Args: + image_tensor: Input torch.Tensor image. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). + + Returns: + Base64 encoded string of the image. + """ + pil_image = tensor_to_pil(image_tensor, total_pixels=total_pixels) + img_byte_arr = pil_to_bytesio(pil_image, mime_type=mime_type) + img_bytes = img_byte_arr.getvalue() + # Encode bytes to base64 string + base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8") + return base64_encoded_string + + +def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: + """Converts a PIL Image to a BytesIO object.""" + if not mime_type: + mime_type = "image/png" + + img_byte_arr = BytesIO() + # Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG') + pil_format = mime_type.split("/")[-1].upper() + if pil_format == "JPG": + pil_format = "JPEG" + img.save(img_byte_arr, format=pil_format) + img_byte_arr.seek(0) + return img_byte_arr + + +def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: + """Downscale input image tensor to roughly the specified total pixels.""" + samples = image.movedim(-1, 1) + total = int(total_pixels) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + if scale_by >= 1: + return image + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = common_upscale(samples, width, height, "lanczos", "disabled") + s = s.movedim(1, -1) + return s + + +def tensor_to_data_uri( + image_tensor: torch.Tensor, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> str: + """Converts a tensor image to a Data URI string. + + Args: + image_tensor: Input torch.Tensor image. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp'). + + Returns: + Data URI string (e.g., 'data:image/png;base64,...'). + """ + base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type) + return f"data:{mime_type};base64,{base64_string}" + + +def audio_to_base64_string(audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac") -> str: + """Converts an audio input to a base64 string.""" + sample_rate: int = audio["sample_rate"] + waveform: torch.Tensor = audio["waveform"] + audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) + audio_bytes = audio_bytes_io.getvalue() + return base64.b64encode(audio_bytes).decode("utf-8") + + +def audio_ndarray_to_bytesio( + audio_data_np: np.ndarray, + sample_rate: int, + container_format: str = "mp4", + codec_name: str = "aac", +) -> BytesIO: + """ + Encodes a numpy array of audio data into a BytesIO object. + """ + audio_bytes_io = BytesIO() + with av.open(audio_bytes_io, mode="w", format=container_format) as output_container: + audio_stream = output_container.add_stream(codec_name, rate=sample_rate) + frame = av.AudioFrame.from_ndarray( + audio_data_np, + format="fltp", + layout="stereo" if audio_data_np.shape[0] > 1 else "mono", + ) + frame.sample_rate = sample_rate + frame.pts = 0 + + for packet in audio_stream.encode(frame): + output_container.mux(packet) + + # Flush stream + for packet in audio_stream.encode(None): + output_container.mux(packet) + + audio_bytes_io.seek(0) + return audio_bytes_io + + +def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray: + """ + Prepares audio waveform for av library by converting to a contiguous numpy array. + + Args: + waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type. + + Returns: + Contiguous numpy array of the audio waveform. If the audio was batched, + the first item is taken. + """ + if waveform.ndim != 3 or waveform.shape[0] != 1: + raise ValueError("Expected waveform tensor shape (1, channels, samples)") + + # If batch is > 1, take first item + if waveform.shape[0] > 1: + waveform = waveform[0] + + # Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array + audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy() + if audio_data_np.dtype != np.float32: + audio_data_np = audio_data_np.astype(np.float32) + + return audio_data_np + + +def audio_input_to_mp3(audio: AudioInput) -> BytesIO: + waveform = audio["waveform"].cpu() + + output_buffer = BytesIO() + output_container = av.open(output_buffer, mode="w", format="mp3") + + out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"]) + out_stream.bit_rate = 320000 + + frame = av.AudioFrame.from_ndarray( + waveform.movedim(0, 1).reshape(1, -1).float().numpy(), + format="flt", + layout="mono" if waveform.shape[0] == 1 else "stereo", + ) + frame.sample_rate = audio["sample_rate"] + frame.pts = 0 + output_container.mux(out_stream.encode(frame)) + output_container.mux(out_stream.encode(None)) + output_container.close() + output_buffer.seek(0) + return output_buffer + + +def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: + """ + Returns a new VideoInput object trimmed from the beginning to the specified duration, + using av to avoid loading entire video into memory. + + Args: + video: Input video to trim + duration_sec: Duration in seconds to keep from the beginning + + Returns: + VideoFromFile object that owns the output buffer + """ + output_buffer = BytesIO() + input_container = None + output_container = None + + try: + # Get the stream source - this avoids loading entire video into memory + # when the source is already a file path + input_source = video.get_stream_source() + + # Open containers + input_container = av.open(input_source, mode="r") + output_container = av.open(output_buffer, mode="w", format="mp4") + + # Set up output streams for re-encoding + video_stream = None + audio_stream = None + + for stream in input_container.streams: + logging.info("Found stream: type=%s, class=%s", stream.type, type(stream)) + if isinstance(stream, av.VideoStream): + # Create output video stream with same parameters + video_stream = output_container.add_stream("h264", rate=stream.average_rate) + video_stream.width = stream.width + video_stream.height = stream.height + video_stream.pix_fmt = "yuv420p" + logging.info("Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate) + elif isinstance(stream, av.AudioStream): + # Create output audio stream with same parameters + audio_stream = output_container.add_stream("aac", rate=stream.sample_rate) + audio_stream.sample_rate = stream.sample_rate + audio_stream.layout = stream.layout + logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels) + + # Calculate target frame count that's divisible by 16 + fps = input_container.streams.video[0].average_rate + estimated_frames = int(duration_sec * fps) + target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16 + + if target_frames == 0: + raise ValueError("Video too short: need at least 16 frames for Moonvalley") + + frame_count = 0 + audio_frame_count = 0 + + # Decode and re-encode video frames + if video_stream: + for frame in input_container.decode(video=0): + if frame_count >= target_frames: + break + + # Re-encode frame + for packet in video_stream.encode(frame): + output_container.mux(packet) + frame_count += 1 + + # Flush encoder + for packet in video_stream.encode(): + output_container.mux(packet) + + logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames) + + # Decode and re-encode audio frames + if audio_stream: + input_container.seek(0) # Reset to beginning for audio + for frame in input_container.decode(audio=0): + if frame.time >= duration_sec: + break + + # Re-encode frame + for packet in audio_stream.encode(frame): + output_container.mux(packet) + audio_frame_count += 1 + + # Flush encoder + for packet in audio_stream.encode(): + output_container.mux(packet) + + logging.info("Encoded %s audio frames", audio_frame_count) + + # Close containers + output_container.close() + input_container.close() + + # Return as VideoFromFile using the buffer + output_buffer.seek(0) + return InputImpl.VideoFromFile(output_buffer) + + except Exception as e: + # Clean up on error + if input_container is not None: + input_container.close() + if output_container is not None: + output_container.close() + raise RuntimeError(f"Failed to trim video: {str(e)}") from e + + +def _f32_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / (2**15) + elif wav.dtype == torch.int32: + return wav.float() / (2**31) + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + + +def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict: + """ + Decode any common audio container from bytes using PyAV and return + a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}. + """ + with av.open(BytesIO(audio_bytes)) as af: + if not af.streams.audio: + raise ValueError("No audio stream found in response.") + stream = af.streams.audio[0] + + in_sr = int(stream.codec_context.sample_rate) + out_sr = in_sr + + frames: list[torch.Tensor] = [] + n_channels = stream.channels or 1 + + for frame in af.decode(streams=stream.index): + arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T] + buf = torch.from_numpy(arr) + if buf.ndim == 1: + buf = buf.unsqueeze(0) # [T] -> [1, T] + elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels: + buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T] + elif buf.shape[0] != n_channels: + buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T] + frames.append(buf) + + if not frames: + raise ValueError("Decoded zero audio frames.") + + wav = torch.cat(frames, dim=1) # [C, T] + wav = _f32_pcm(wav) + return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index 90e127b74433..bffe9bc2fd65 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -4,147 +4,128 @@ import time import uuid from io import BytesIO -from typing import Optional, Union, IO from pathlib import Path +from typing import IO, Optional, Union +from urllib.parse import urlparse import aiohttp import torch from aiohttp.client_exceptions import ClientError, ContentTypeError -from urllib.parse import urlparse +from comfy_api.input_impl import VideoFromFile +from comfy_api.latest import IO as ComfyIO from comfy_api_nodes.apis import request_logger -from ._helpers import _is_processing_interrupted -from .common_exceptions import ProcessingInterrupted, LocalNetworkError, ApiServerError -from .api_client import _diagnose_connectivity +from ._helpers import default_base_url, get_auth_header, is_processing_interrupted +from .client import _diagnose_connectivity +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted from .conversions import bytesio_to_image_tensor - _RETRY_STATUS = {408, 429, 500, 502, 503, 504} async def download_url_to_bytesio( url: str, - timeout: Optional[float] = None, + dest: Optional[Union[BytesIO, IO[bytes], str, Path]], *, - dest: Optional[Union[BytesIO, IO[bytes], str, Path]] = None, + timeout: Optional[float] = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff: float = 2.0, + cls: type[ComfyIO.ComfyNode] = None, ) -> None: - """Stream-download a URL into memory or to a provided destination. + """Stream-download a URL to `dest`. + + `dest` must be one of: + - a BytesIO (rewound to 0 after write), + - a file-like object opened in binary write mode (must implement .write()), + - a filesystem path (str | pathlib.Path), which will be opened with 'wb'. + + If `url` starts with `/proxy/`, `cls` must be provided so the URL can be expanded + to an absolute URL and authentication headers can be applied. Raises: ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors) """ + if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"): + raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().") + attempt = 0 delay = retry_delay + headers = {} + if url.startswith("/proxy/"): + if cls is None: + raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.") + url = default_base_url().rstrip("/") + url + headers = get_auth_header(cls) while True: attempt += 1 op_id = _generate_operation_id("GET", url, attempt) timeout_cfg = aiohttp.ClientTimeout(total=timeout) - stop_evt = asyncio.Event() - - async def _monitor(): - try: - while not stop_evt.is_set(): - if _is_processing_interrupted(): - return - await asyncio.sleep(1.0) - except asyncio.CancelledError: - return - monitor_task: Optional[asyncio.Task] = None - sess: Optional[aiohttp.ClientSession] = None - - # Open file path if a path was provided is_path_sink = isinstance(dest, (str, Path)) fhandle = None try: - try: - request_logger.log_request_response( - operation_id=op_id, - request_method="GET", - request_url=url, - ) - except Exception as e: - logging.debug("[DEBUG] download request logging failed: %s", e) - - monitor_task = asyncio.create_task(_monitor()) - sess = aiohttp.ClientSession(timeout=timeout_cfg) - req_task = asyncio.create_task(sess.get(url)) - - done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) - - # Interruption wins the race - if monitor_task in done and req_task in pending: - req_task.cancel() - raise ProcessingInterrupted("Task cancelled") - - resp = await req_task - async with resp: - if resp.status >= 400: - # Attempt to capture body for logging (do not log huge binaries) + with contextlib.suppress(Exception): + request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url) + + async with aiohttp.ClientSession(timeout=timeout_cfg) as session: + async with session.get(url, headers=headers) as resp: + if resp.status >= 400: + with contextlib.suppress(Exception): + try: + body = await resp.json() + except (ContentTypeError, ValueError): + text = await resp.text() + body = text if len(text) <= 4096 else f"[text {len(text)} bytes]" + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=f"HTTP {resp.status}", + ) + + if resp.status in _RETRY_STATUS and attempt <= max_retries: + await _sleep_with_cancel(delay) + delay *= retry_backoff + continue + raise Exception(f"Failed to download (HTTP {resp.status}).") + + if is_path_sink: + p = Path(str(dest)) + with contextlib.suppress(Exception): + p.parent.mkdir(parents=True, exist_ok=True) + fhandle = open(p, "wb") + sink = fhandle + else: + sink = dest # BytesIO or file-like + + written = 0 + async for chunk in resp.content.iter_chunked(1024 * 1024): + sink.write(chunk) + written += len(chunk) + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + + if isinstance(dest, BytesIO): + with contextlib.suppress(Exception): + dest.seek(0) + with contextlib.suppress(Exception): - try: - body = await resp.json() - except (ContentTypeError, ValueError): - text = await resp.text() - body = text if len(text) <= 4096 else f"[text {len(text)} bytes]" request_logger.log_request_response( operation_id=op_id, request_method="GET", request_url=url, response_status_code=resp.status, response_headers=dict(resp.headers), - response_content=body, - error_message=f"HTTP {resp.status}", + response_content=f"[streamed {written} bytes to dest]", ) + return - if resp.status in _RETRY_STATUS and attempt <= max_retries: - await _sleep_with_cancel(delay) - delay *= retry_backoff - continue - raise Exception(f"Failed to download (HTTP {resp.status}).") - - # Prepare path sink if needed - if is_path_sink: - p = Path(str(dest)) - with contextlib.suppress(Exception): - p.parent.mkdir(parents=True, exist_ok=True) - fhandle = open(p, "wb") - sink = fhandle - else: - sink = dest # BytesIO or file-like - - # Stream body in chunks to sink with cancellation checks - written = 0 - last_tick = time.monotonic() - async for chunk in resp.content.iter_chunked(1024 * 1024): - sink.write(chunk) - written += len(chunk) - now = time.monotonic() - if now - last_tick >= 1.0: - last_tick = now - if _is_processing_interrupted(): - raise ProcessingInterrupted("Task cancelled") - - if isinstance(dest, BytesIO): - dest.seek(0) - - try: - request_logger.log_request_response( - operation_id=op_id, - request_method="GET", - request_url=url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=f"[streamed {written} bytes to dest]", - ) - except Exception as e: - logging.debug("[DEBUG] download response logging failed: %s", e) - return except ProcessingInterrupted: logging.debug("Download was interrupted by user") raise @@ -172,57 +153,30 @@ async def _monitor(): if fhandle: fhandle.flush() fhandle.close() - stop_evt.set() - if monitor_task: - monitor_task.cancel() - with contextlib.suppress(Exception): - await monitor_task - if sess: - with contextlib.suppress(Exception): - await sess.close() async def download_url_to_image_tensor( url: str, - timeout: int = None, - auth_kwargs: Optional[dict[str, str]] = None, *, - dest: Optional[Union[BytesIO, IO[bytes], str, Path]] = None, - mode: str = "RGBA", + timeout: float = None, + cls: type[ComfyIO.ComfyNode] = None, ) -> torch.Tensor: - """ - Download image and decode to tensor. Supports streaming `dest` like util version. - """ - if dest is None: - bio = await download_url_to_bytesio(url, timeout, auth_kwargs, dest=None) - return bytesio_to_image_tensor(bio, mode=mode) # type: ignore[arg-type] - - await download_url_to_bytesio(url, timeout, auth_kwargs, dest=dest) - - if isinstance(dest, BytesIO): - with contextlib.suppress(Exception): - dest.seek(0) - return bytesio_to_image_tensor(dest, mode=mode) + """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" + result = BytesIO() + await download_url_to_bytesio(url, result, timeout=timeout, cls=cls) + return bytesio_to_image_tensor(result) - if hasattr(dest, "read") and hasattr(dest, "seek"): - try: - with contextlib.suppress(Exception): - dest.flush() - dest.seek(0) - data = dest.read() - return bytesio_to_image_tensor(BytesIO(data), mode=mode) - except Exception: - pass - if isinstance(dest, (str, Path)) or getattr(dest, "name", None): - path_str = str(dest if isinstance(dest, (str, Path)) else getattr(dest, "name")) - with open(path_str, "rb") as f: - return bytesio_to_image_tensor(BytesIO(f.read()), mode=mode) - - raise ValueError( - "Destination is not readable and no path is available to decode the image. " - "Pass dest=None to decode from memory, or provide a readable handle / path." - ) +async def download_url_to_video_output( + video_url: str, + *, + timeout: float = None, + cls: type[ComfyIO.ComfyNode] = None, +) -> VideoFromFile: + """Downloads a video from a URL and returns a `VIDEO` output.""" + result = BytesIO() + await download_url_to_bytesio(video_url, result, timeout=timeout, cls=cls) + return VideoFromFile(result) def _generate_operation_id(method: str, url: str, attempt: int) -> str: @@ -238,7 +192,7 @@ async def _sleep_with_cancel(seconds: float) -> None: """Sleep in 1s slices while checking for interruption.""" end = time.monotonic() + seconds while True: - if _is_processing_interrupted(): + if is_processing_interrupted(): raise ProcessingInterrupted("Task cancelled") now = time.monotonic() if now >= end: diff --git a/comfy_api_nodes/util/storage_helpers.py b/comfy_api_nodes/util/storage_helpers.py deleted file mode 100644 index d8af624efe65..000000000000 --- a/comfy_api_nodes/util/storage_helpers.py +++ /dev/null @@ -1,272 +0,0 @@ -import uuid -import asyncio -import contextlib -from io import BytesIO -import logging -import time -from typing import Optional, Union - -import aiohttp -import torch -from pydantic import BaseModel, Field - -from comfy_api.latest import IO -from urllib.parse import urlparse -from .api_client import ( - ApiEndpoint, - sync_op_pydantic, - _display_time_progress, - _diagnose_connectivity, -) - -from comfy_api_nodes.apis import request_logger -from comfy_api_nodes.apinode_utils import tensor_to_bytesio -from ._helpers import _sleep_with_interrupt, _is_processing_interrupted -from .common_exceptions import ProcessingInterrupted, LocalNetworkError, ApiServerError - - -class UploadRequest(BaseModel): - file_name: str = Field(..., description="Filename to upload") - content_type: Optional[str] = Field( - None, - description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", - ) - - -class UploadResponse(BaseModel): - download_url: str = Field(..., description="URL to GET uploaded file") - upload_url: str = Field(..., description="URL to PUT file to upload") - - -async def upload_images_to_comfyapi( - cls: type[IO.ComfyNode], - image: torch.Tensor, - *, - max_images: int = 8, - mime_type: Optional[str] = None, - status_update: bool = True, -) -> list[str]: - """ - Uploads images to ComfyUI API and returns download URLs. - To upload multiple images, stack them in the batch dimension first. - """ - # if batch, try to upload each file if max_images is greater than 0 - download_urls: list[str] = [] - is_batch = len(image.shape) > 3 - batch_len = image.shape[0] if is_batch else 1 - - for idx in range(min(batch_len, max_images)): - tensor = image[idx] if is_batch else image - img_io = tensor_to_bytesio(tensor, mime_type=mime_type) - url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, status_update) - download_urls.append(url) - return download_urls - - -async def upload_file_to_comfyapi( - cls: type[IO.ComfyNode], - file_bytes_io: BytesIO, - filename: str, - upload_mime_type: Optional[str], - status_update: bool = True, -) -> str: - """Uploads a single file to ComfyUI API and returns its download URL.""" - if upload_mime_type is None: - request_object = UploadRequest(file_name=filename) - else: - request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) - create_resp = await sync_op_pydantic( - cls, - endpoint=ApiEndpoint(path="/customers/storage", method="POST"), - data=request_object, - response_model=UploadResponse, - final_label_on_success=None, - monitor_progress=False, - ) - await upload_file( - cls, create_resp.upload_url, - file_bytes_io, - content_type=upload_mime_type, - wait_label="Uploading" if status_update else None, - ) - return create_resp.download_url - - -async def upload_file( - cls: type[IO.ComfyNode], - upload_url: str, - file: Union[BytesIO, str], - *, - content_type: Optional[str] = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff: float = 2.0, - wait_label: Optional[str] = None, -) -> None: - """ - Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption. - - Args: - cls: Node class (provides auth context + UI progress hooks). - upload_url: Pre-signed PUT URL. - file: BytesIO or path string. - content_type: Explicit MIME type. If None, we *suppress* Content-Type. - max_retries: Maximum retry attempts. - retry_delay: Initial delay in seconds. - retry_backoff: Exponential backoff factor. - wait_label: Progress label shown in Comfy UI. - - Raises: - ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception - """ - if isinstance(file, BytesIO): - with contextlib.suppress(Exception): - file.seek(0) - data = file.read() - elif isinstance(file, str): - with open(file, "rb") as f: - data = f.read() - else: - raise ValueError("file must be a BytesIO or a filesystem path string") - - headers: dict[str, str] = {} - skip_auto_headers: set[str] = set() - if content_type: - headers["Content-Type"] = content_type - else: - skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request - - attempt = 0 - delay = retry_delay - start_ts = time.monotonic() - op_uuid = uuid.uuid4().hex[:8] - while True: - attempt += 1 - operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid) - timeout = aiohttp.ClientTimeout(total=None) - stop_evt = asyncio.Event() - - async def _monitor(): - try: - while not stop_evt.is_set(): - if _is_processing_interrupted(): - return - if wait_label: - _display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None) - await asyncio.sleep(1.0) - except asyncio.CancelledError: - return - - monitor_task = asyncio.create_task(_monitor()) - sess: Optional[aiohttp.ClientSession] = None - try: - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - request_headers=headers or None, - request_params=None, - request_data=f"[File data {len(data)} bytes]", - ) - except Exception as e: - logging.debug("[DEBUG] upload request logging failed: %s", e) - - sess = aiohttp.ClientSession(timeout=timeout) - req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers) - req_task = asyncio.create_task(req) - - done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) - - if monitor_task in done and req_task in pending: - req_task.cancel() - raise ProcessingInterrupted("Upload cancelled") - - resp = await req_task - async with resp: - if resp.status >= 400: - with contextlib.suppress(Exception): - try: - body = await resp.json() - except Exception: - body = await resp.text() - msg = f"Upload failed with status {resp.status}" - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=body, - error_message=msg, - ) - if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries: - await _sleep_with_interrupt( - delay, - cls, - wait_label, - start_ts, - None, - display_callback=_display_time_progress if wait_label else None, - ) - delay *= retry_backoff - continue - raise Exception(f"Failed to upload (HTTP {resp.status}).") - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content="File uploaded successfully.", - ) - except Exception as e: - logging.debug("[DEBUG] upload response logging failed: %s", e) - return - except (aiohttp.ClientError, asyncio.TimeoutError) as e: - if attempt <= max_retries: - with contextlib.suppress(Exception): - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - request_headers=headers or None, - request_data=f"[File data {len(data)} bytes]", - error_message=f"{type(e).__name__}: {str(e)} (will retry)", - ) - await _sleep_with_interrupt( - delay, - cls, - wait_label, - start_ts, - None, - display_callback=_display_time_progress if wait_label else None, - ) - delay *= retry_backoff - continue - - diag = await _diagnose_connectivity() - if diag.get("is_local_issue"): - raise LocalNetworkError( - "Unable to connect to the network. Please check your internet connection and try again." - ) from e - raise ApiServerError("The API service appears unreachable at this time.") from e - finally: - stop_evt.set() - if monitor_task: - monitor_task.cancel() - with contextlib.suppress(Exception): - await monitor_task - if sess: - with contextlib.suppress(Exception): - await sess.close() - - -def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str: - try: - parsed = urlparse(url) - slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_") - except Exception: - slug = "upload" - return f"{method}_{slug}_{op_uuid}_try{attempt}" diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index 457a6f11ecf4..fa9d10b8e802 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -1,28 +1,35 @@ -import uuid import asyncio import contextlib -from io import BytesIO import logging import time +import uuid +from io import BytesIO from typing import Optional, Union +from urllib.parse import urlparse import aiohttp import torch from pydantic import BaseModel, Field +from comfy_api.input.basic_types import AudioInput +from comfy_api.input.video_types import VideoInput from comfy_api.latest import IO -from urllib.parse import urlparse -from .api_client import ( +from comfy_api.util import VideoCodec, VideoContainer +from comfy_api_nodes.apis import request_logger + +from ._helpers import is_processing_interrupted, sleep_with_interrupt +from .client import ( ApiEndpoint, - sync_op_pydantic, - _display_time_progress, _diagnose_connectivity, + _display_time_progress, + sync_op, +) +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted +from .conversions import ( + audio_ndarray_to_bytesio, + audio_tensor_to_contiguous_ndarray, + tensor_to_bytesio, ) - -from comfy_api_nodes.apis import request_logger -from comfy_api_nodes.apinode_utils import tensor_to_bytesio -from ._helpers import _sleep_with_interrupt, _is_processing_interrupted -from .common_exceptions import ProcessingInterrupted, LocalNetworkError, ApiServerError class UploadRequest(BaseModel): @@ -63,6 +70,60 @@ async def upload_images_to_comfyapi( return download_urls +async def upload_audio_to_comfyapi( + cls: type[IO.ComfyNode], + audio: AudioInput, + *, + container_format: str = "mp4", + codec_name: str = "aac", + mime_type: str = "audio/mp4", + filename: str = "uploaded_audio.mp4", +) -> str: + """ + Uploads a single audio input to ComfyUI API and returns its download URL. + Encodes the raw waveform into the specified format before uploading. + """ + sample_rate: int = audio["sample_rate"] + waveform: torch.Tensor = audio["waveform"] + audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) + return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type) + + +async def upload_video_to_comfyapi( + cls: type[IO.ComfyNode], + video: VideoInput, + *, + container: VideoContainer = VideoContainer.MP4, + codec: VideoCodec = VideoCodec.H264, + max_duration: Optional[int] = None, +) -> str: + """ + Uploads a single video to ComfyUI API and returns its download URL. + Uses the specified container and codec for saving the video before upload. + """ + if max_duration is not None: + try: + actual_duration = video.duration_seconds + if actual_duration is not None and actual_duration > max_duration: + raise ValueError( + f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)." + ) + except Exception as e: + logging.error("Error getting video duration: %s", str(e)) + raise ValueError(f"Could not verify video duration from source: {e}") from e + + upload_mime_type = f"video/{container.value.lower()}" + filename = f"uploaded_video.{container.value.lower()}" + + # Convert VideoInput to BytesIO using specified container/codec + video_bytes_io = BytesIO() + video.save_to(video_bytes_io, format=container, codec=codec) + video_bytes_io.seek(0) + + return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type) + + async def upload_file_to_comfyapi( cls: type[IO.ComfyNode], file_bytes_io: BytesIO, @@ -75,7 +136,7 @@ async def upload_file_to_comfyapi( request_object = UploadRequest(file_name=filename) else: request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) - create_resp = await sync_op_pydantic( + create_resp = await sync_op( cls, endpoint=ApiEndpoint(path="/customers/storage", method="POST"), data=request_object, @@ -84,7 +145,8 @@ async def upload_file_to_comfyapi( monitor_progress=False, ) await upload_file( - cls, create_resp.upload_url, + cls, + create_resp.upload_url, file_bytes_io, content_type=upload_mime_type, wait_label=wait_label, @@ -149,7 +211,7 @@ async def upload_file( async def _monitor(): try: while not stop_evt.is_set(): - if _is_processing_interrupted(): + if is_processing_interrupted(): return if wait_label: _display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None) @@ -201,7 +263,7 @@ async def _monitor(): error_message=msg, ) if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries: - await _sleep_with_interrupt( + await sleep_with_interrupt( delay, cls, wait_label, @@ -235,7 +297,7 @@ async def _monitor(): request_data=f"[File data {len(data)} bytes]", error_message=f"{type(e).__name__}: {str(e)} (will retry)", ) - await _sleep_with_interrupt( + await sleep_with_interrupt( delay, cls, wait_label, diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index ca913e9b3eed..22da05bc199d 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -2,6 +2,8 @@ from typing import Optional import torch + +from comfy_api.input.video_types import VideoInput from comfy_api.latest import Input @@ -28,9 +30,7 @@ def validate_image_dimensions( if max_width is not None and width > max_width: raise ValueError(f"Image width must be at most {max_width}px, got {width}px") if min_height is not None and height < min_height: - raise ValueError( - f"Image height must be at least {min_height}px, got {height}px" - ) + raise ValueError(f"Image height must be at least {min_height}px, got {height}px") if max_height is not None and height > max_height: raise ValueError(f"Image height must be at most {max_height}px, got {height}px") @@ -44,13 +44,9 @@ def validate_image_aspect_ratio( aspect_ratio = width / height if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio: - raise ValueError( - f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}" - ) + raise ValueError(f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}") if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio: - raise ValueError( - f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}" - ) + raise ValueError(f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}") def validate_image_aspect_ratio_range( @@ -58,7 +54,7 @@ def validate_image_aspect_ratio_range( min_ratio: tuple[float, float], # e.g. (1, 4) max_ratio: tuple[float, float], # e.g. (4, 1) *, - strict: bool = True, # True -> (min, max); False -> [min, max] + strict: bool = True, # True -> (min, max); False -> [min, max] ) -> float: a1, b1 = min_ratio a2, b2 = max_ratio @@ -85,7 +81,7 @@ def validate_aspect_ratio_closeness( min_rel: float, max_rel: float, *, - strict: bool = False, # True => exclusive, False => inclusive + strict: bool = False, # True => exclusive, False => inclusive ) -> None: w1, h1 = get_image_dimensions(start_img) w2, h2 = get_image_dimensions(end_img) @@ -118,9 +114,7 @@ def validate_video_dimensions( if max_width is not None and width > max_width: raise ValueError(f"Video width must be at most {max_width}px, got {width}px") if min_height is not None and height < min_height: - raise ValueError( - f"Video height must be at least {min_height}px, got {height}px" - ) + raise ValueError(f"Video height must be at least {min_height}px, got {height}px") if max_height is not None and height > max_height: raise ValueError(f"Video height must be at most {max_height}px, got {height}px") @@ -138,13 +132,9 @@ def validate_video_duration( epsilon = 0.0001 if min_duration is not None and min_duration - epsilon > duration: - raise ValueError( - f"Video duration must be at least {min_duration}s, got {duration}s" - ) + raise ValueError(f"Video duration must be at least {min_duration}s, got {duration}s") if max_duration is not None and duration > max_duration + epsilon: - raise ValueError( - f"Video duration must be at most {max_duration}s, got {duration}s" - ) + raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s") def get_number_of_images(images): @@ -165,3 +155,31 @@ def validate_audio_duration( raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s") if max_duration is not None and dur - eps > max_duration: raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s") + + +def validate_string( + string: str, + strip_whitespace=True, + field_name="prompt", + min_length=None, + max_length=None, +): + if string is None: + raise Exception(f"Field '{field_name}' cannot be empty.") + if strip_whitespace: + string = string.strip() + if min_length and len(string) < min_length: + raise Exception( + f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long." + ) + if max_length and len(string) > max_length: + raise Exception( + f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long." + ) + + +def validate_container_format_is_mp4(video: VideoInput) -> None: + """Validates video container format is MP4.""" + container_format = video.get_container_format() + if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: + raise ValueError(f"Only MP4 container format supported. Got: {container_format}") diff --git a/pyproject.toml b/pyproject.toml index fcbcb3dd919e..29023d517bb1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,3 +69,12 @@ messages_control.disable = [ "no-else-return", "unused-variable", ] + + +[tool.black] +line-length = 120 +preview = true + + +[tool.isort] +profile = "black" From b34bc7987bdf1a0eb758735388657bb6ba9d93b9 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Fri, 17 Oct 2025 19:51:45 +0300 Subject: [PATCH 4/8] converted WAN nodes to use new client; polishing --- comfy_api_nodes/nodes_moonvalley.py | 37 ++-- comfy_api_nodes/nodes_runway.py | 71 +++---- comfy_api_nodes/nodes_vidu.py | 27 +-- comfy_api_nodes/nodes_wan.py | 243 ++++++++++------------- comfy_api_nodes/util/_helpers.py | 2 +- comfy_api_nodes/util/client.py | 68 ++++--- comfy_api_nodes/util/conversions.py | 10 +- comfy_api_nodes/util/download_helpers.py | 188 +++++++++++------- comfy_api_nodes/util/upload_helpers.py | 20 +- pyproject.toml | 9 - 10 files changed, 341 insertions(+), 334 deletions(-) diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 426875d32a16..7c31d95b300a 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -1,32 +1,31 @@ import logging from typing import Optional + import torch from typing_extensions import override +from comfy_api.input import VideoInput +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis import ( - MoonvalleyTextToVideoRequest, + MoonvalleyPromptResponse, MoonvalleyTextToVideoInferenceParams, + MoonvalleyTextToVideoRequest, MoonvalleyVideoToVideoInferenceParams, MoonvalleyVideoToVideoRequest, - MoonvalleyPromptResponse, ) from comfy_api_nodes.util import ( - validate_container_format_is_mp4, - validate_image_dimensions, - download_url_to_video_output, - upload_video_to_comfyapi, - upload_images_to_comfyapi, - sync_op, ApiEndpoint, + download_url_to_video_output, poll_op, - validate_string, + sync_op, trim_video, + upload_images_to_comfyapi, + upload_video_to_comfyapi, + validate_container_format_is_mp4, + validate_image_dimensions, + validate_string, ) -from comfy_api.input import VideoInput -from comfy_api.latest import ComfyExtension, IO - - API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads" API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts" API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video" @@ -103,12 +102,8 @@ def _validate_video_dimensions(width: int, height: int) -> None: } if (width, height) not in supported_resolutions: - supported_list = ", ".join( - [f"{w}x{h}" for w, h in sorted(supported_resolutions)] - ) - raise ValueError( - f"Resolution {width}x{height} not supported. Supported: {supported_list}" - ) + supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)]) + raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}") def _validate_and_trim_duration(video: VideoInput) -> VideoInput: @@ -160,6 +155,8 @@ async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromp ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"), response_model=MoonvalleyPromptResponse, status_extractor=lambda r: (r.status if r and r.status else None), + poll_interval=16.0, + max_poll_attempts=240, ) @@ -269,7 +266,7 @@ async def execute( # Get MIME type from tensor - assuming PNG format for image tensors mime_type = "image/png" - image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0] + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0] task_creation_response = await sync_op( cls, endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"), diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index aac69167ec9c..0543d1d0e27c 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -21,7 +21,6 @@ RunwayImageToVideoRequest, RunwayImageToVideoResponse, RunwayTaskStatusResponse as TaskStatusResponse, - RunwayTaskStatusEnum as TaskStatus, RunwayModelEnum as Model, RunwayDurationEnum as Duration, RunwayAspectRatioEnum as AspectRatio, @@ -33,7 +32,18 @@ ReferenceImage, RunwayTextToImageAspectRatioEnum, ) -from comfy_api_nodes.util import image_tensor_pair_to_batch, validate_string, validate_image_dimensions, validate_image_aspect_ratio, upload_images_to_comfyapi, download_url_to_video_output, download_url_to_image_tensor, ApiEndpoint, sync_op, poll_op +from comfy_api_nodes.util import ( + image_tensor_pair_to_batch, + validate_string, + validate_image_dimensions, + validate_image_aspect_ratio, + upload_images_to_comfyapi, + download_url_to_video_output, + download_url_to_image_tensor, + ApiEndpoint, + sync_op, + poll_op, +) from comfy_api.input_impl import VideoFromFile from comfy_api.latest import ComfyExtension, IO @@ -93,20 +103,12 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N async def get_response( - cls: type[IO.ComfyNode], - task_id: str, estimated_duration: Optional[int] = None + cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None ) -> TaskStatusResponse: """Poll the task status until it is finished then get the response.""" return await poll_op( cls, ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"), - completed_statuses=[ - TaskStatus.SUCCEEDED.value, - ], - failed_statuses=[ - TaskStatus.FAILED.value, - TaskStatus.CANCELLED.value, - ], response_model=TaskStatusResponse, status_extractor=lambda r: r.status.value, estimated_duration=estimated_duration, @@ -143,9 +145,9 @@ def define_schema(cls): display_name="Runway Image to Video (Gen3a Turbo)", category="api node/video/Runway", description="Generate a video from a single starting frame using Gen3a Turbo model. " - "Before diving in, review these best practices to ensure that " - "your input selections will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.", + "Before diving in, review these best practices to ensure that " + "your input selections will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.", inputs=[ IO.String.Input( "prompt", @@ -217,11 +219,7 @@ async def execute( duration=Duration(duration), ratio=AspectRatio(ratio), promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ) - ] + root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")] ), ), ) @@ -237,9 +235,9 @@ def define_schema(cls): display_name="Runway Image to Video (Gen4 Turbo)", category="api node/video/Runway", description="Generate a video from a single starting frame using Gen4 Turbo model. " - "Before diving in, review these best practices to ensure that " - "your input selections will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.", + "Before diving in, review these best practices to ensure that " + "your input selections will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.", inputs=[ IO.String.Input( "prompt", @@ -311,11 +309,7 @@ async def execute( duration=Duration(duration), ratio=AspectRatio(ratio), promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ) - ] + root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")] ), ), estimated_duration=AVERAGE_DURATION_FLF_SECONDS, @@ -332,12 +326,12 @@ def define_schema(cls): display_name="Runway First-Last-Frame to Video", category="api node/video/Runway", description="Upload first and last keyframes, draft a prompt, and generate a video. " - "More complex transitions, such as cases where the Last frame is completely different " - "from the First frame, may benefit from the longer 10s duration. " - "This would give the generation more time to smoothly transition between the two inputs. " - "Before diving in, review these best practices to ensure that your input selections " - "will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.", + "More complex transitions, such as cases where the Last frame is completely different " + "from the First frame, may benefit from the longer 10s duration. " + "This would give the generation more time to smoothly transition between the two inputs. " + "Before diving in, review these best practices to ensure that your input selections " + "will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.", inputs=[ IO.String.Input( "prompt", @@ -420,12 +414,8 @@ async def execute( ratio=AspectRatio(ratio), promptImage=RunwayPromptImageObject( root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ), - RunwayPromptImageDetailedObject( - uri=str(download_urls[1]), position="last" - ), + RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first"), + RunwayPromptImageDetailedObject(uri=str(download_urls[1]), position="last"), ] ), ), @@ -443,7 +433,7 @@ def define_schema(cls): display_name="Runway Text to Image", category="api node/image/Runway", description="Generate an image from a text prompt using Runway's Gen 4 model. " - "You can also include reference image to guide the generation.", + "You can also include reference image to guide the generation.", inputs=[ IO.String.Input( "prompt", @@ -527,5 +517,6 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: RunwayTextToImageNode, ] + async def comfy_entrypoint() -> RunwayExtension: return RunwayExtension() diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 9c4d30bc3b24..0e0572f8c7c8 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -1,25 +1,24 @@ import logging from enum import Enum -from typing import Optional, Literal, TypeVar -from typing_extensions import override +from typing import Literal, Optional, TypeVar import torch from pydantic import BaseModel, Field +from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.util import ( - validate_aspect_ratio_closeness, - validate_image_dimensions, - validate_image_aspect_ratio_range, - get_number_of_images, - download_url_to_video_output, - upload_images_to_comfyapi, ApiEndpoint, - sync_op, + download_url_to_video_output, + get_number_of_images, poll_op, + sync_op, + upload_images_to_comfyapi, + validate_aspect_ratio_closeness, + validate_image_aspect_ratio_range, + validate_image_dimensions, ) - VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video" VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video" @@ -28,8 +27,9 @@ R = TypeVar("R") + class VideoModelName(str, Enum): - vidu_q1 = 'viduq1' + vidu_q1 = "viduq1" class AspectRatio(str, Enum): @@ -102,7 +102,7 @@ async def execute_task( ) -> R: response = await sync_op( cls, - endpoint=ApiEndpoint(path=vidu_endpoint,method="POST"), + endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"), response_model=TaskCreationResponse, data=payload, ) @@ -560,5 +560,6 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: ViduStartEndToVideoNode, ] + async def comfy_entrypoint() -> ViduExtension: return ViduExtension() diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 61d50746b2de..2aab3c2ffb5c 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -1,20 +1,22 @@ import re -from typing import Optional, Type, Union -from typing_extensions import override +from typing import Optional import torch from pydantic import BaseModel, Field -from comfy_api.latest import ComfyExtension, Input, IO -from comfy_api_nodes.apis.client import ( +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, - R, - T, + audio_to_base64_string, + download_url_to_image_tensor, + download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, + tensor_to_base64_string, + validate_audio_duration, ) -from comfy_api_nodes.util import get_number_of_images, validate_audio_duration, tensor_to_base64_string, audio_to_base64_string, download_url_to_video_output, download_url_to_image_tensor class Text2ImageInputField(BaseModel): @@ -140,53 +142,7 @@ class VideoTaskStatusResponse(BaseModel): request_id: str = Field(...) -RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)') - - -async def process_task( - auth_kwargs: dict[str, str], - url: str, - request_model: Type[T], - response_model: Type[R], - payload: Union[ - Text2ImageTaskCreationRequest, - Image2ImageTaskCreationRequest, - Text2VideoTaskCreationRequest, - Image2VideoTaskCreationRequest, - ], - node_id: str, - estimated_duration: int, - poll_interval: int, -) -> Type[R]: - initial_response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=url, - method=HttpMethod.POST, - request_model=request_model, - response_model=TaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - - if not initial_response.output: - raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") - - return await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=response_model, - ), - completed_statuses=["SUCCEEDED"], - failed_statuses=["FAILED", "CANCELED", "UNKNOWN"], - status_extractor=lambda x: x.output.task_status, - estimated_duration=estimated_duration, - poll_interval=poll_interval, - node_id=node_id, - auth_kwargs=auth_kwargs, - ).execute() +RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)") class WanTextToImageApi(IO.ComfyNode): @@ -253,7 +209,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -280,26 +236,28 @@ async def execute( prompt_extend: bool = True, watermark: bool = True, ): - payload = Text2ImageTaskCreationRequest( - model=model, - input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt), - parameters=Txt2ImageParametersField( - size=f"{width}*{height}", - seed=seed, - prompt_extend=prompt_extend, - watermark=watermark, + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Text2ImageTaskCreationRequest( + model=model, + input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt), + parameters=Txt2ImageParametersField( + size=f"{width}*{height}", + seed=seed, + prompt_extend=prompt_extend, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", - request_model=Text2ImageTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=ImageTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=9, poll_interval=3, ) @@ -314,7 +272,7 @@ def define_schema(cls): display_name="Wan Image to Image", category="api node/image/Wan", description="Generates an image from one or two input images and a text prompt. " - "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).", + "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).", inputs=[ IO.Combo.Input( "model", @@ -370,7 +328,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -402,28 +360,30 @@ async def execute( raise ValueError(f"Expected 1 or 2 input images, got {n_images}.") images = [] for i in image: - images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096)) - payload = Image2ImageTaskCreationRequest( - model=model, - input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images), - parameters=Image2ImageParametersField( - # size=f"{width}*{height}", - seed=seed, - watermark=watermark, + images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096)) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Image2ImageTaskCreationRequest( + model=model, + input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images), + parameters=Image2ImageParametersField( + # size=f"{width}*{height}", + seed=seed, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", - request_model=Image2ImageTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=ImageTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=42, - poll_interval=3, + poll_interval=4, ) return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url))) @@ -517,7 +477,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -551,28 +511,31 @@ async def execute( if audio is not None: validate_audio_duration(audio, 3.0, 29.0) audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") - payload = Text2VideoTaskCreationRequest( - model=model, - input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url), - parameters=Text2VideoParametersField( - size=f"{width}*{height}", - duration=duration, - seed=seed, - audio=generate_audio, - prompt_extend=prompt_extend, - watermark=watermark, + + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Text2VideoTaskCreationRequest( + model=model, + input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url), + parameters=Text2VideoParametersField( + size=f"{width}*{height}", + duration=duration, + seed=seed, + audio=generate_audio, + prompt_extend=prompt_extend, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", - request_model=Text2VideoTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=VideoTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=120 * int(duration / 5), poll_interval=6, ) @@ -661,7 +624,7 @@ def define_schema(cls): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -693,35 +656,37 @@ async def execute( ): if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") - image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000) + image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000) audio_url = None if audio is not None: validate_audio_duration(audio, 3.0, 29.0) audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") - payload = Image2VideoTaskCreationRequest( - model=model, - input=Image2VideoInputField( - prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url - ), - parameters=Image2VideoParametersField( - resolution=resolution, - duration=duration, - seed=seed, - audio=generate_audio, - prompt_extend=prompt_extend, - watermark=watermark, + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Image2VideoTaskCreationRequest( + model=model, + input=Image2VideoInputField( + prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url + ), + parameters=Image2VideoParametersField( + resolution=resolution, + duration=duration, + seed=seed, + audio=generate_audio, + prompt_extend=prompt_extend, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", - request_model=Image2VideoTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=VideoTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=120 * int(duration / 5), poll_interval=6, ) diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index bf6c32d49771..328fe52272fd 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -35,7 +35,7 @@ def default_base_url() -> str: async def sleep_with_interrupt( seconds: float, - node_cls: type[IO.ComfyNode], + node_cls: Optional[type[IO.ComfyNode]], label: Optional[str] = None, start_ts: Optional[float] = None, estimated_total: Optional[int] = None, diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 184e23824451..d6986a79c73b 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from enum import Enum from io import BytesIO -from typing import Any, Callable, Literal, Optional, Type, TypeVar, Union +from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union from urllib.parse import urljoin, urlparse import aiohttp @@ -79,7 +79,7 @@ class _PollUIState: _RETRY_STATUS = {408, 429, 500, 502, 503, 504} COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"] -FAILED_STATUSES = ["cancelled", "failed", "error"] +FAILED_STATUSES = ["cancelled", "canceled", "failed", "error"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] @@ -97,7 +97,7 @@ async def sync_op( retry_delay: float = 1.0, retry_backoff: float = 2.0, wait_label: str = "Waiting for server", - estimated_total: Optional[int] = None, + estimated_duration: Optional[int] = None, final_label_on_success: Optional[str] = "Completed", progress_origin_ts: Optional[float] = None, monitor_progress: bool = True, @@ -114,7 +114,7 @@ async def sync_op( retry_delay=retry_delay, retry_backoff=retry_backoff, wait_label=wait_label, - estimated_total=estimated_total, + estimated_duration=estimated_duration, as_binary=False, final_label_on_success=final_label_on_success, progress_origin_ts=progress_origin_ts, @@ -130,7 +130,7 @@ async def poll_op( poll_endpoint: ApiEndpoint, *, response_model: Type[M], - status_extractor: Callable[[M], Optional[str]], + status_extractor: Callable[[M], Optional[Union[str, int]]], progress_extractor: Optional[Callable[[M], Optional[int]]] = None, price_extractor: Optional[Callable[[M], Optional[float]]] = None, completed_statuses: Optional[list[Union[str, int]]] = None, @@ -183,7 +183,7 @@ async def sync_op_raw( retry_delay: float = 1.0, retry_backoff: float = 2.0, wait_label: str = "Waiting for server", - estimated_total: Optional[int] = None, + estimated_duration: Optional[int] = None, as_binary: bool = False, final_label_on_success: Optional[str] = "Completed", progress_origin_ts: Optional[float] = None, @@ -212,7 +212,7 @@ async def sync_op_raw( retry_backoff=retry_backoff, wait_label=wait_label, monitor_progress=monitor_progress, - estimated_total=estimated_total, + estimated_total=estimated_duration, final_label_on_success=final_label_on_success, progress_origin_ts=progress_origin_ts, ) @@ -223,7 +223,7 @@ async def poll_op_raw( cls: type[IO.ComfyNode], poll_endpoint: ApiEndpoint, *, - status_extractor: Callable[[dict[str, Any]], Optional[str]], + status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]], progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None, price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, completed_statuses: Optional[list[Union[str, int]]] = None, @@ -247,9 +247,9 @@ async def poll_op_raw( Returns the final JSON response from the poll endpoint. """ - completed_states = COMPLETED_STATUSES if completed_statuses is None else completed_statuses - failed_states = FAILED_STATUSES if failed_statuses is None else failed_statuses - queued_states = QUEUED_STATUSES if queued_statuses is None else queued_statuses + completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses) + failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses) + queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses) started = time.monotonic() consumed_attempts = 0 # counts only non-queued polls @@ -271,7 +271,7 @@ async def _ticker(): ) _display_time_progress( cls, - label=state.status_label, + status=state.status_label, elapsed_seconds=int(now - state.started), estimated_total=state.estimated_duration, price=state.price, @@ -294,7 +294,7 @@ async def _ticker(): retry_delay=retry_delay_per_poll, retry_backoff=retry_backoff_per_poll, wait_label="Checking", - estimated_total=None, + estimated_duration=None, as_binary=False, final_label_on_success=None, monitor_progress=False, @@ -310,7 +310,7 @@ async def _ticker(): timeout=cancel_timeout, max_retries=0, wait_label="Cancelling task", - estimated_total=None, + estimated_duration=None, as_binary=False, final_label_on_success=None, monitor_progress=False, @@ -318,7 +318,7 @@ async def _ticker(): raise try: - status = status_extractor(resp_json) + status = _normalize_status_value(status_extractor(resp_json)) except Exception as e: logging.error("Status extraction failed: %s", e) status = None @@ -360,7 +360,7 @@ async def _ticker(): _display_time_progress( cls, - label=status if status else "Completed", + status=status if status else "Completed", elapsed_seconds=int(now_ts - started), estimated_total=estimated_duration, price=state.price, @@ -385,7 +385,7 @@ async def _ticker(): timeout=cancel_timeout, max_retries=0, wait_label="Cancelling task", - estimated_total=None, + estimated_duration=None, as_binary=False, final_label_on_success=None, monitor_progress=False, @@ -414,12 +414,12 @@ def _display_text( node_cls: type[IO.ComfyNode], text: Optional[str], *, - status: Optional[str] = None, + status: Optional[Union[str, int]] = None, price: Optional[float] = None, ) -> None: display_lines: list[str] = [] if status: - display_lines.append(f"Status: {status.capitalize()}") + display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") if price is not None: display_lines.append(f"Price: ${float(price):,.4f}") if text is not None: @@ -430,7 +430,7 @@ def _display_text( def _display_time_progress( node_cls: type[IO.ComfyNode], - label: str, + status: Optional[Union[str, int]], elapsed_seconds: int, estimated_total: Optional[int] = None, *, @@ -444,7 +444,7 @@ def _display_time_progress( time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)" else: time_line = f"Time elapsed: {int(elapsed_seconds)}s" - _display_text(node_cls, time_line, status=label, price=price) + _display_text(node_cls, time_line, status=status, price=price) async def _diagnose_connectivity() -> dict[str, bool]: @@ -640,7 +640,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): with contextlib.suppress(Exception): file_value.seek(0) form.add_field(field_name, file_value, filename=filename, content_type=content_type) - payload_kw["data"] = form # do not send body on GET + payload_kw["data"] = form elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET": payload_headers["Content-Type"] = "application/x-www-form-urlencoded" payload_kw["data"] = cfg.data or {} @@ -660,7 +660,6 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): except Exception as _log_e: logging.debug("[DEBUG] request logging failed: %s", _log_e) - # Compose the HTTP request coroutine req_coro = sess.request(method, url, params=params, **payload_kw) req_task = asyncio.create_task(req_coro) @@ -684,7 +683,6 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): body = await resp.json() except (ContentTypeError, json.JSONDecodeError): body = await resp.text() - # Retryable? if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries: logging.warning( "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).", @@ -733,9 +731,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): logging.debug("[DEBUG] response logging failed: %s", _log_e) raise Exception(msg) - # Success if expect_binary: - # Read stream in chunks so that cancellation is fast when user interrupts buff = bytearray() last_tick = time.monotonic() async for chunk in resp.content.iter_chunked(64 * 1024): @@ -794,7 +790,6 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): logging.debug("Polling was interrupted by user") raise except (ClientError, asyncio.TimeoutError, socket.gaierror) as e: - # Retry transient connection issues if attempt <= cfg.max_retries: logging.warning( "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s", @@ -873,7 +868,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success: _display_time_progress( cfg.node_cls, - label=cfg.final_label_on_success, + status=cfg.final_label_on_success, elapsed_seconds=( final_elapsed_seconds if final_elapsed_seconds is not None @@ -926,3 +921,20 @@ def _wrapped(d: dict[str, Any]) -> Any: raise return _wrapped + + +def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]: + if not values: + return set() + out: set[Union[str, int]] = set() + for v in values: + nv = _normalize_status_value(v) + if nv is not None: + out.add(nv) + return out + + +def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]: + if isinstance(val, str): + return val.strip().lower() + return val diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 3cf81459e173..10cd1051b4d7 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -11,9 +11,7 @@ from PIL import Image from comfy.utils import common_upscale -from comfy_api.input import VideoInput -from comfy_api.input.basic_types import AudioInput -from comfy_api.latest import InputImpl +from comfy_api.latest import Input, InputImpl from ._helpers import mimetype_to_extension @@ -165,7 +163,7 @@ def tensor_to_data_uri( return f"data:{mime_type};base64,{base64_string}" -def audio_to_base64_string(audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac") -> str: +def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", codec_name: str = "aac") -> str: """Converts an audio input to a base64 string.""" sample_rate: int = audio["sample_rate"] waveform: torch.Tensor = audio["waveform"] @@ -232,7 +230,7 @@ def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray: return audio_data_np -def audio_input_to_mp3(audio: AudioInput) -> BytesIO: +def audio_input_to_mp3(audio: Input.Audio) -> BytesIO: waveform = audio["waveform"].cpu() output_buffer = BytesIO() @@ -255,7 +253,7 @@ def audio_input_to_mp3(audio: AudioInput) -> BytesIO: return output_buffer -def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: +def trim_video(video: Input.Video, duration_sec: float) -> Input.Video: """ Returns a new VideoInput object trimmed from the beginning to the specified duration, using av to avoid loading entire video into memory. diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index bffe9bc2fd65..d0b65424bd1d 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -1,7 +1,5 @@ import asyncio import contextlib -import logging -import time import uuid from io import BytesIO from pathlib import Path @@ -13,10 +11,15 @@ from aiohttp.client_exceptions import ClientError, ContentTypeError from comfy_api.input_impl import VideoFromFile -from comfy_api.latest import IO as ComfyIO +from comfy_api.latest import IO as COMFY_IO from comfy_api_nodes.apis import request_logger -from ._helpers import default_base_url, get_auth_header, is_processing_interrupted +from ._helpers import ( + default_base_url, + get_auth_header, + is_processing_interrupted, + sleep_with_interrupt, +) from .client import _diagnose_connectivity from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted from .conversions import bytesio_to_image_tensor @@ -32,7 +35,7 @@ async def download_url_to_bytesio( max_retries: int = 3, retry_delay: float = 1.0, retry_backoff: float = 2.0, - cls: type[ComfyIO.ComfyNode] = None, + cls: type[COMFY_IO.ComfyNode] = None, ) -> None: """Stream-download a URL to `dest`. @@ -52,7 +55,8 @@ async def download_url_to_bytesio( attempt = 0 delay = retry_delay - headers = {} + headers: dict[str, str] = {} + if url.startswith("/proxy/"): if cls is None: raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.") @@ -66,69 +70,112 @@ async def download_url_to_bytesio( is_path_sink = isinstance(dest, (str, Path)) fhandle = None + session: Optional[aiohttp.ClientSession] = None + stop_evt: Optional[asyncio.Event] = None + monitor_task: Optional[asyncio.Task] = None + req_task: Optional[asyncio.Task] = None + try: with contextlib.suppress(Exception): request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url) - async with aiohttp.ClientSession(timeout=timeout_cfg) as session: - async with session.get(url, headers=headers) as resp: - if resp.status >= 400: - with contextlib.suppress(Exception): - try: - body = await resp.json() - except (ContentTypeError, ValueError): - text = await resp.text() - body = text if len(text) <= 4096 else f"[text {len(text)} bytes]" - request_logger.log_request_response( - operation_id=op_id, - request_method="GET", - request_url=url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=body, - error_message=f"HTTP {resp.status}", - ) - - if resp.status in _RETRY_STATUS and attempt <= max_retries: - await _sleep_with_cancel(delay) - delay *= retry_backoff - continue - raise Exception(f"Failed to download (HTTP {resp.status}).") - - if is_path_sink: - p = Path(str(dest)) - with contextlib.suppress(Exception): - p.parent.mkdir(parents=True, exist_ok=True) - fhandle = open(p, "wb") - sink = fhandle - else: - sink = dest # BytesIO or file-like - - written = 0 - async for chunk in resp.content.iter_chunked(1024 * 1024): - sink.write(chunk) - written += len(chunk) + session = aiohttp.ClientSession(timeout=timeout_cfg) + stop_evt = asyncio.Event() + + async def _monitor(): + try: + while not stop_evt.is_set(): if is_processing_interrupted(): - raise ProcessingInterrupted("Task cancelled") + return + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return + + monitor_task = asyncio.create_task(_monitor()) + + req_task = asyncio.create_task(session.get(url, headers=headers)) + done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task in done and req_task in pending: + req_task.cancel() + with contextlib.suppress(Exception): + await req_task + raise ProcessingInterrupted("Task cancelled") - if isinstance(dest, BytesIO): - with contextlib.suppress(Exception): - dest.seek(0) + try: + resp = await req_task + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + async with resp: + if resp.status >= 400: with contextlib.suppress(Exception): + try: + body = await resp.json() + except (ContentTypeError, ValueError): + text = await resp.text() + body = text if len(text) <= 4096 else f"[text {len(text)} bytes]" request_logger.log_request_response( operation_id=op_id, request_method="GET", request_url=url, response_status_code=resp.status, response_headers=dict(resp.headers), - response_content=f"[streamed {written} bytes to dest]", + response_content=body, + error_message=f"HTTP {resp.status}", ) - return - except ProcessingInterrupted: - logging.debug("Download was interrupted by user") - raise + if resp.status in _RETRY_STATUS and attempt <= max_retries: + await sleep_with_interrupt(delay, cls, None, None, None) + delay *= retry_backoff + continue + raise Exception(f"Failed to download (HTTP {resp.status}).") + + if is_path_sink: + p = Path(str(dest)) + with contextlib.suppress(Exception): + p.parent.mkdir(parents=True, exist_ok=True) + fhandle = open(p, "wb") + sink = fhandle + else: + sink = dest # BytesIO or file-like + + written = 0 + while True: + try: + chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0) + except asyncio.TimeoutError: + chunk = b"" + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + + if not chunk: + if resp.content.at_eof(): + break + continue + + sink.write(chunk) + written += len(chunk) + + if isinstance(dest, BytesIO): + with contextlib.suppress(Exception): + dest.seek(0) + + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=f"[streamed {written} bytes to dest]", + ) + return + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None except (ClientError, asyncio.TimeoutError) as e: if attempt <= max_retries: with contextlib.suppress(Exception): @@ -138,7 +185,7 @@ async def download_url_to_bytesio( request_url=url, error_message=f"{type(e).__name__}: {str(e)} (will retry)", ) - await _sleep_with_cancel(delay) + await sleep_with_interrupt(delay, cls, None, None, None) delay *= retry_backoff continue @@ -149,8 +196,21 @@ async def download_url_to_bytesio( ) from e raise ApiServerError("The remote service appears unreachable at this time.") from e finally: - with contextlib.suppress(Exception): - if fhandle: + if stop_evt is not None: + stop_evt.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if req_task and not req_task.done(): + req_task.cancel() + with contextlib.suppress(Exception): + await req_task + if session: + with contextlib.suppress(Exception): + await session.close() + if fhandle: + with contextlib.suppress(Exception): fhandle.flush() fhandle.close() @@ -159,7 +219,7 @@ async def download_url_to_image_tensor( url: str, *, timeout: float = None, - cls: type[ComfyIO.ComfyNode] = None, + cls: type[COMFY_IO.ComfyNode] = None, ) -> torch.Tensor: """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" result = BytesIO() @@ -171,7 +231,7 @@ async def download_url_to_video_output( video_url: str, *, timeout: float = None, - cls: type[ComfyIO.ComfyNode] = None, + cls: type[COMFY_IO.ComfyNode] = None, ) -> VideoFromFile: """Downloads a video from a URL and returns a `VIDEO` output.""" result = BytesIO() @@ -186,15 +246,3 @@ def _generate_operation_id(method: str, url: str, attempt: int) -> str: except Exception: slug = "download" return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" - - -async def _sleep_with_cancel(seconds: float) -> None: - """Sleep in 1s slices while checking for interruption.""" - end = time.monotonic() + seconds - while True: - if is_processing_interrupted(): - raise ProcessingInterrupted("Task cancelled") - now = time.monotonic() - if now >= end: - return - await asyncio.sleep(min(1.0, end - now)) diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index fa9d10b8e802..a345d451d4cd 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -11,9 +11,7 @@ import torch from pydantic import BaseModel, Field -from comfy_api.input.basic_types import AudioInput -from comfy_api.input.video_types import VideoInput -from comfy_api.latest import IO +from comfy_api.latest import IO, Input from comfy_api.util import VideoCodec, VideoContainer from comfy_api_nodes.apis import request_logger @@ -72,7 +70,7 @@ async def upload_images_to_comfyapi( async def upload_audio_to_comfyapi( cls: type[IO.ComfyNode], - audio: AudioInput, + audio: Input.Audio, *, container_format: str = "mp4", codec_name: str = "aac", @@ -92,7 +90,7 @@ async def upload_audio_to_comfyapi( async def upload_video_to_comfyapi( cls: type[IO.ComfyNode], - video: VideoInput, + video: Input.Video, *, container: VideoContainer = VideoContainer.MP4, codec: VideoCodec = VideoCodec.H264, @@ -104,8 +102,8 @@ async def upload_video_to_comfyapi( """ if max_duration is not None: try: - actual_duration = video.duration_seconds - if actual_duration is not None and actual_duration > max_duration: + actual_duration = video.get_duration() + if actual_duration > max_duration: raise ValueError( f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)." ) @@ -244,7 +242,11 @@ async def _monitor(): req_task.cancel() raise ProcessingInterrupted("Upload cancelled") - resp = await req_task + try: + resp = await req_task + except asyncio.CancelledError: + raise ProcessingInterrupted("Upload cancelled") from None + async with resp: if resp.status >= 400: with contextlib.suppress(Exception): @@ -286,6 +288,8 @@ async def _monitor(): except Exception as e: logging.debug("[DEBUG] upload response logging failed: %s", e) return + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None except (aiohttp.ClientError, asyncio.TimeoutError) as e: if attempt <= max_retries: with contextlib.suppress(Exception): diff --git a/pyproject.toml b/pyproject.toml index 29023d517bb1..fcbcb3dd919e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,12 +69,3 @@ messages_control.disable = [ "no-else-return", "unused-variable", ] - - -[tool.black] -line-length = 120 -preview = true - - -[tool.isort] -profile = "black" From fab58ddfa10e9e2fa8b2b959b387795aaafe59bc Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Thu, 23 Oct 2025 09:50:39 +0300 Subject: [PATCH 5/8] fix(auth): do not leak authentification for the absolute urls --- comfy_api_nodes/util/client.py | 25 +++++++++++------------- comfy_api_nodes/util/download_helpers.py | 7 ++++--- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index d6986a79c73b..1c37d16868d3 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -482,18 +482,6 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: raise ValueError("files tuple must be (filename, file[, content_type])") -def _join_url(base_url: str, path: str) -> str: - return urljoin(base_url.rstrip("/") + "/", path.lstrip("/")) - - -def _merge_headers(node_cls: type[IO.ComfyNode], endpoint_headers: dict[str, str]) -> dict[str, str]: - headers = {"Accept": "*/*"} - headers.update(get_auth_header(node_cls)) - if endpoint_headers: - headers.update(endpoint_headers) - return headers - - def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]: params = dict(endpoint_params or {}) if method.upper() == "GET" and data: @@ -566,7 +554,11 @@ def _snapshot_request_body_for_logging( async def _request_base(cfg: _RequestConfig, expect_binary: bool): """Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors.""" - url = _join_url(default_base_url(), cfg.endpoint.path) + url = cfg.endpoint.path + parsed_url = urlparse(url) + if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? + url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) + method = cfg.endpoint.method params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None) @@ -598,7 +590,12 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float): operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) - payload_headers = _merge_headers(cfg.node_cls, cfg.endpoint.headers) + payload_headers = {"Accept": "*/*"} + if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? + payload_headers.update(get_auth_header(cfg.node_cls)) + if cfg.endpoint.headers: + payload_headers.update(cfg.endpoint.headers) + payload_kw: dict[str, Any] = {"headers": payload_headers} if method == "GET": payload_headers.pop("Content-Type", None) diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index d0b65424bd1d..055e690de4e5 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -4,7 +4,7 @@ from io import BytesIO from pathlib import Path from typing import IO, Optional, Union -from urllib.parse import urlparse +from urllib.parse import urljoin, urlparse import aiohttp import torch @@ -57,10 +57,11 @@ async def download_url_to_bytesio( delay = retry_delay headers: dict[str, str] = {} - if url.startswith("/proxy/"): + parsed_url = urlparse(url) + if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? if cls is None: raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.") - url = default_base_url().rstrip("/") + url + url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) headers = get_auth_header(cls) while True: From 6dadfa2cb427247b2fb43c3cfc3b380555f404e9 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Thu, 23 Oct 2025 09:51:50 +0300 Subject: [PATCH 6/8] convert BFL API nodes to use new API client; remove deprecated BFL nodes --- comfy_api_nodes/apis/bfl_api.py | 51 +-- comfy_api_nodes/nodes_bfl.py | 606 ++++++++------------------------ 2 files changed, 152 insertions(+), 505 deletions(-) diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl_api.py index 0e90aef7c681..0fc8c060767c 100644 --- a/comfy_api_nodes/apis/bfl_api.py +++ b/comfy_api_nodes/apis/bfl_api.py @@ -50,44 +50,6 @@ class BFLFluxFillImageRequest(BaseModel): mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.') -class BFLFluxCannyImageRequest(BaseModel): - prompt: str = Field(..., description='Text prompt for image generation') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection') - canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection') - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided') - preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step') - - -class BFLFluxDepthImageRequest(BaseModel): - prompt: str = Field(..., description='Text prompt for image generation') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided') - preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step') - - class BFLFluxProGenerateRequest(BaseModel): prompt: str = Field(..., description='The text prompt for image generation.') prompt_upsampling: Optional[bool] = Field( @@ -160,15 +122,8 @@ class BFLStatus(str, Enum): error = "Error" -class BFLFluxProStatusResponse(BaseModel): +class BFLFluxStatusResponse(BaseModel): id: str = Field(..., description="The unique identifier for the generation task.") status: BFLStatus = Field(..., description="The status of the task.") - result: Optional[Dict[str, Any]] = Field( - None, description="The result of the task (null if not completed)." - ) - progress: confloat(ge=0.0, le=1.0) = Field( - ..., description="The progress of the task (0.0 to 1.0)." - ) - details: Optional[Dict[str, Any]] = Field( - None, description="Additional details about the task (null if not available)." - ) + result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).") + progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0) diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 3e83eb127df7..baa74fd529d8 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,39 +1,32 @@ -import asyncio -import io from inspect import cleandoc -from typing import Union, Optional +from typing import Optional + +import torch from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO + +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apinode_utils import ( + resize_mask_to_image, + validate_aspect_ratio, +) from comfy_api_nodes.apis.bfl_api import ( - BFLStatus, BFLFluxExpandImageRequest, BFLFluxFillImageRequest, - BFLFluxCannyImageRequest, - BFLFluxDepthImageRequest, - BFLFluxProGenerateRequest, BFLFluxKontextProGenerateRequest, - BFLFluxProUltraGenerateRequest, + BFLFluxProGenerateRequest, BFLFluxProGenerateResponse, + BFLFluxProUltraGenerateRequest, + BFLFluxStatusResponse, + BFLStatus, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, + download_url_to_image_tensor, + poll_op, + sync_op, + tensor_to_base64_string, + validate_string, ) -from comfy_api_nodes.apinode_utils import ( - validate_aspect_ratio, - process_image_response, - resize_mask_to_image, -) -from comfy_api_nodes.util import validate_string, downscale_image_tensor - -import numpy as np -from PIL import Image -import aiohttp -import torch -import base64 -import time -from server import PromptServer def convert_mask_to_image(mask: torch.Tensor): @@ -41,95 +34,10 @@ def convert_mask_to_image(mask: torch.Tensor): Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. """ mask = mask.unsqueeze(-1) - mask = torch.cat([mask]*3, dim=-1) + mask = torch.cat([mask] * 3, dim=-1) return mask -async def handle_bfl_synchronous_operation( - operation: SynchronousOperation, - timeout_bfl_calls=360, - node_id: Union[str, None] = None, -): - response_api: BFLFluxProGenerateResponse = await operation.execute() - return await _poll_until_generated( - response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id - ) - - -async def _poll_until_generated( - polling_url: str, timeout=360, node_id: Union[str, None] = None -): - # used bfl-comfy-nodes to verify code implementation: - # https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main - start_time = time.time() - retries_404 = 0 - max_retries_404 = 5 - retry_404_seconds = 2 - retry_202_seconds = 2 - retry_pending_seconds = 1 - - async with aiohttp.ClientSession() as session: - # NOTE: should True loop be replaced with checking if workflow has been interrupted? - while True: - if node_id: - time_elapsed = time.time() - start_time - PromptServer.instance.send_progress_text( - f"Generating ({time_elapsed:.0f}s)", node_id - ) - - async with session.get(polling_url) as response: - if response.status == 200: - result = await response.json() - if result["status"] == BFLStatus.ready: - img_url = result["result"]["sample"] - if node_id: - PromptServer.instance.send_progress_text( - f"Result URL: {img_url}", node_id - ) - async with session.get(img_url) as img_resp: - return process_image_response(await img_resp.content.read()) - elif result["status"] in [ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - ]: - status = result["status"] - raise Exception( - f"BFL API did not return an image due to: {status}." - ) - elif result["status"] == BFLStatus.error: - raise Exception(f"BFL API encountered an error: {result}.") - elif result["status"] == BFLStatus.pending: - await asyncio.sleep(retry_pending_seconds) - continue - elif response.status == 404: - if retries_404 < max_retries_404: - retries_404 += 1 - await asyncio.sleep(retry_404_seconds) - continue - raise Exception( - f"BFL API could not find task after {max_retries_404} tries." - ) - elif response.status == 202: - await asyncio.sleep(retry_202_seconds) - elif time.time() - start_time > timeout: - raise Exception( - f"BFL API experienced a timeout; could not return request under {timeout} seconds." - ) - else: - raise Exception(f"BFL API encountered an error: {response.json()}") - -def convert_image_to_base64(image: torch.Tensor): - scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048) - # remove batch dimension if present - if len(scaled_image.shape) > 3: - scaled_image = scaled_image[0] - image_np = (scaled_image.numpy() * 255).astype(np.uint8) - img = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format="PNG") - return base64.b64encode(img_byte_arr.getvalue()).decode() - - class FluxProUltraImageNode(IO.ComfyNode): """ Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. @@ -157,7 +65,9 @@ def define_schema(cls) -> IO.Schema: IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Int.Input( "seed", @@ -219,22 +129,19 @@ async def execute( cls, prompt: str, aspect_ratio: str, - prompt_upsampling=False, - raw=False, - seed=0, - image_prompt=None, - image_prompt_strength=0.1, + prompt_upsampling: bool = False, + raw: bool = False, + seed: int = 0, + image_prompt: Optional[torch.Tensor] = None, + image_prompt_strength: float = 0.1, ) -> IO.NodeOutput: if image_prompt is None: validate_string(prompt, strip_whitespace=False) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.1-ultra/generate", - method=HttpMethod.POST, - request_model=BFLFluxProUltraGenerateRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxProUltraGenerateRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.1-ultra/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxProUltraGenerateRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, seed=seed, @@ -246,22 +153,26 @@ async def execute( maximum_ratio_str=cls.MAXIMUM_RATIO_STR, ), raw=raw, - image_prompt=( - image_prompt - if image_prompt is None - else convert_image_to_base64(image_prompt) - ), - image_prompt_strength=( - None if image_prompt is None else round(image_prompt_strength, 2) - ), + image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)), + image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)), ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxKontextProImageNode(IO.ComfyNode): @@ -346,7 +257,7 @@ async def execute( aspect_ratio: str, guidance: float, steps: int, - input_image: Optional[torch.Tensor]=None, + input_image: Optional[torch.Tensor] = None, seed=0, prompt_upsampling=False, ) -> IO.NodeOutput: @@ -359,33 +270,36 @@ async def execute( ) if input_image is None: validate_string(prompt, strip_whitespace=False) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=cls.BFL_PATH, - method=HttpMethod.POST, - request_model=BFLFluxKontextProGenerateRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxKontextProGenerateRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path=cls.BFL_PATH, method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxKontextProGenerateRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, guidance=round(guidance, 1), steps=steps, seed=seed, aspect_ratio=aspect_ratio, - input_image=( - input_image - if input_image is None - else convert_image_to_base64(input_image) - ) + input_image=(input_image if input_image is None else tensor_to_base64_string(input_image)), ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxKontextMaxImageNode(FluxKontextProImageNode): @@ -421,7 +335,9 @@ def define_schema(cls) -> IO.Schema: IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Int.Input( "width", @@ -480,20 +396,15 @@ async def execute( image_prompt=None, # image_prompt_strength=0.1, ) -> IO.NodeOutput: - image_prompt = ( - image_prompt - if image_prompt is None - else convert_image_to_base64(image_prompt) - ) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( + image_prompt = image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt) + initial_response = await sync_op( + cls, + ApiEndpoint( path="/proxy/bfl/flux-pro-1.1/generate", - method=HttpMethod.POST, - request_model=BFLFluxProGenerateRequest, - response_model=BFLFluxProGenerateResponse, + method="POST", ), - request=BFLFluxProGenerateRequest( + response_model=BFLFluxProGenerateResponse, + data=BFLFluxProGenerateRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, width=width, @@ -501,13 +412,23 @@ async def execute( seed=seed, image_prompt=image_prompt, ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxProExpandNode(IO.ComfyNode): @@ -533,7 +454,9 @@ def define_schema(cls) -> IO.Schema: IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Int.Input( "top", @@ -609,16 +532,11 @@ async def execute( guidance: float, seed=0, ) -> IO.NodeOutput: - image = convert_image_to_base64(image) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-expand/generate", - method=HttpMethod.POST, - request_model=BFLFluxExpandImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxExpandImageRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-expand/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxExpandImageRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, top=top, @@ -628,16 +546,25 @@ async def execute( steps=steps, guidance=guidance, seed=seed, - image=image, + image=tensor_to_base64_string(image), ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) - + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxProFillNode(IO.ComfyNode): @@ -664,7 +591,9 @@ def define_schema(cls) -> IO.Schema: IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Float.Input( "guidance", @@ -711,272 +640,37 @@ async def execute( ) -> IO.NodeOutput: # prepare mask mask = resize_mask_to_image(mask, image) - mask = convert_image_to_base64(convert_mask_to_image(mask)) - # make sure image will have alpha channel removed - image = convert_image_to_base64(image[:, :, :, :3]) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-fill/generate", - method=HttpMethod.POST, - request_model=BFLFluxFillImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxFillImageRequest( + mask = tensor_to_base64_string(convert_mask_to_image(mask)) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-fill/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxFillImageRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, steps=steps, guidance=guidance, seed=seed, - image=image, + image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed mask=mask, ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) - - -class FluxProCannyNode(IO.ComfyNode): - """ - Generate image using a control image (canny). - """ - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="FluxProCannyNode", - display_name="Flux.1 Canny Control Image", - category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), - inputs=[ - IO.Image.Input("control_image"), - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Prompt for the image generation", - ), - IO.Boolean.Input( - "prompt_upsampling", - default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - ), - IO.Float.Input( - "canny_low_threshold", - default=0.1, - min=0.01, - max=0.99, - step=0.01, - tooltip="Low threshold for Canny edge detection; ignored if skip_processing is True", - ), - IO.Float.Input( - "canny_high_threshold", - default=0.4, - min=0.01, - max=0.99, - step=0.01, - tooltip="High threshold for Canny edge detection; ignored if skip_processing is True", - ), - IO.Boolean.Input( - "skip_preprocessing", - default=False, - tooltip="Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.", - ), - IO.Float.Input( - "guidance", - default=30, - min=1, - max=100, - tooltip="Guidance strength for the image generation process", - ), - IO.Int.Input( - "steps", - default=50, - min=15, - max=50, - tooltip="Number of steps for the image generation process", - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=0xFFFFFFFFFFFFFFFF, - control_after_generate=True, - tooltip="The random seed used for creating the noise.", - ), - ], - outputs=[IO.Image.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - ) - - @classmethod - async def execute( - cls, - control_image: torch.Tensor, - prompt: str, - prompt_upsampling: bool, - canny_low_threshold: float, - canny_high_threshold: float, - skip_preprocessing: bool, - steps: int, - guidance: float, - seed=0, - ) -> IO.NodeOutput: - control_image = convert_image_to_base64(control_image[:, :, :, :3]) - preprocessed_image = None - - # scale canny threshold between 0-500, to match BFL's API - def scale_value(value: float, min_val=0, max_val=500): - return min_val + value * (max_val - min_val) - canny_low_threshold = int(round(scale_value(canny_low_threshold))) - canny_high_threshold = int(round(scale_value(canny_high_threshold))) - - - if skip_preprocessing: - preprocessed_image = control_image - control_image = None - canny_low_threshold = None - canny_high_threshold = None - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-canny/generate", - method=HttpMethod.POST, - request_model=BFLFluxCannyImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxCannyImageRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - steps=steps, - guidance=guidance, - seed=seed, - control_image=control_image, - canny_low_threshold=canny_low_threshold, - canny_high_threshold=canny_high_threshold, - preprocessed_image=preprocessed_image, - ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) - - -class FluxProDepthNode(IO.ComfyNode): - """ - Generate image using a control image (depth). - """ - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="FluxProDepthNode", - display_name="Flux.1 Depth Control Image", - category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), - inputs=[ - IO.Image.Input("control_image"), - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Prompt for the image generation", - ), - IO.Boolean.Input( - "prompt_upsampling", - default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - ), - IO.Boolean.Input( - "skip_preprocessing", - default=False, - tooltip="Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.", - ), - IO.Float.Input( - "guidance", - default=15, - min=1, - max=100, - tooltip="Guidance strength for the image generation process", - ), - IO.Int.Input( - "steps", - default=50, - min=15, - max=50, - tooltip="Number of steps for the image generation process", - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=0xFFFFFFFFFFFFFFFF, - control_after_generate=True, - tooltip="The random seed used for creating the noise.", - ), - ], - outputs=[IO.Image.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, ], - is_api_node=True, - ) - - @classmethod - async def execute( - cls, - control_image: torch.Tensor, - prompt: str, - prompt_upsampling: bool, - skip_preprocessing: bool, - steps: int, - guidance: float, - seed=0, - ) -> IO.NodeOutput: - control_image = convert_image_to_base64(control_image[:,:,:,:3]) - preprocessed_image = None - - if skip_preprocessing: - preprocessed_image = control_image - control_image = None - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-depth/generate", - method=HttpMethod.POST, - request_model=BFLFluxDepthImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxDepthImageRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - steps=steps, - guidance=guidance, - seed=seed, - control_image=control_image, - preprocessed_image=preprocessed_image, - ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, + queued_statuses=[], ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class BFLExtension(ComfyExtension): @@ -989,8 +683,6 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: FluxKontextMaxImageNode, FluxProExpandNode, FluxProFillNode, - FluxProCannyNode, - FluxProDepthNode, ] From b7916fdf9be209c74f0c2385e0b0425f6726b956 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Thu, 23 Oct 2025 15:41:49 +0300 Subject: [PATCH 7/8] converted Google Veo nodes --- comfy_api_nodes/apis/veo_api.py | 111 ++++++++++++++++++++ comfy_api_nodes/nodes_veo2.py | 173 ++++++++++---------------------- comfy_api_nodes/util/client.py | 4 + 3 files changed, 167 insertions(+), 121 deletions(-) create mode 100644 comfy_api_nodes/apis/veo_api.py diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo_api.py new file mode 100644 index 000000000000..a55137afbec1 --- /dev/null +++ b/comfy_api_nodes/apis/veo_api.py @@ -0,0 +1,111 @@ +from typing import Optional, Union +from enum import Enum + +from pydantic import BaseModel, Field + + +class Image2(BaseModel): + bytesBase64Encoded: str + gcsUri: Optional[str] = None + mimeType: Optional[str] = None + + +class Image3(BaseModel): + bytesBase64Encoded: Optional[str] = None + gcsUri: str + mimeType: Optional[str] = None + + +class Instance1(BaseModel): + image: Optional[Union[Image2, Image3]] = Field( + None, description='Optional image to guide video generation' + ) + prompt: str = Field(..., description='Text description of the video') + + +class PersonGeneration1(str, Enum): + ALLOW = 'ALLOW' + BLOCK = 'BLOCK' + + +class Parameters1(BaseModel): + aspectRatio: Optional[str] = Field(None, examples=['16:9']) + durationSeconds: Optional[int] = None + enhancePrompt: Optional[bool] = None + generateAudio: Optional[bool] = Field( + None, + description='Generate audio for the video. Only supported by veo 3 models.', + ) + negativePrompt: Optional[str] = None + personGeneration: Optional[PersonGeneration1] = None + sampleCount: Optional[int] = None + seed: Optional[int] = None + storageUri: Optional[str] = Field( + None, description='Optional Cloud Storage URI to upload the video' + ) + + +class VeoGenVidRequest(BaseModel): + instances: Optional[list[Instance1]] = None + parameters: Optional[Parameters1] = None + + +class VeoGenVidResponse(BaseModel): + name: str = Field( + ..., + description='Operation resource name', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8' + ], + ) + + +class VeoGenVidPollRequest(BaseModel): + operationName: str = Field( + ..., + description='Full operation name (from predict response)', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID' + ], + ) + + +class Video(BaseModel): + bytesBase64Encoded: Optional[str] = Field( + None, description='Base64-encoded video content' + ) + gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video') + mimeType: Optional[str] = Field(None, description='Video MIME type') + + +class Error1(BaseModel): + code: Optional[int] = Field(None, description='Error code') + message: Optional[str] = Field(None, description='Error message') + + +class Response1(BaseModel): + field_type: Optional[str] = Field( + None, + alias='@type', + examples=[ + 'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse' + ], + ) + raiMediaFilteredCount: Optional[int] = Field( + None, description='Count of media filtered by responsible AI policies' + ) + raiMediaFilteredReasons: Optional[list[str]] = Field( + None, description='Reasons why media was filtered by responsible AI policies' + ) + videos: Optional[list[Video]] = None + + +class VeoGenVidPollResponse(BaseModel): + done: Optional[bool] = None + error: Optional[Error1] = Field( + None, description='Error details if operation failed' + ) + name: Optional[str] = None + response: Optional[Response1] = Field( + None, description='The actual prediction response if done is true' + ) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 88b247d27977..2b17ce5380b6 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,28 +1,24 @@ -import logging import base64 -import aiohttp -import torch from io import BytesIO -from typing import Optional + from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO from comfy_api.input_impl.video_types import VideoFromFile -from comfy_api_nodes.apis import ( - VeoGenVidRequest, - VeoGenVidResponse, +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apis.veo_api import ( VeoGenVidPollRequest, VeoGenVidPollResponse, + VeoGenVidRequest, + VeoGenVidResponse, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, + download_url_to_video_output, + poll_op, + sync_op, + tensor_to_base64_string, ) -from comfy_api_nodes.util import downscale_image_tensor, tensor_to_base64_string - AVERAGE_DURATION_VIDEO_GEN = 32 MODELS_MAP = { "veo-2.0-generate-001": "veo-2.0-generate-001", @@ -32,28 +28,6 @@ "veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001", } -def convert_image_to_base64(image: torch.Tensor): - if image is None: - return None - - scaled_image = downscale_image_tensor(image, total_pixels=2048*2048) - return tensor_to_base64_string(scaled_image) - - -def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]: - if ( - poll_response.response - and hasattr(poll_response.response, "videos") - and poll_response.response.videos - and len(poll_response.response.videos) > 0 - ): - video = poll_response.response.videos[0] - else: - return None - if hasattr(video, "gcsUri") and video.gcsUri: - return str(video.gcsUri) - return None - class VeoVideoGenerationNode(IO.ComfyNode): """ @@ -166,18 +140,13 @@ async def execute( # Prepare the instances for the request instances = [] - instance = { - "prompt": prompt - } + instance = {"prompt": prompt} # Add image if provided if image is not None: - image_base64 = convert_image_to_base64(image) + image_base64 = tensor_to_base64_string(image) if image_base64: - instance["image"] = { - "bytesBase64Encoded": image_base64, - "mimeType": "image/png" - } + instance["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"} instances.append(instance) @@ -198,116 +167,74 @@ async def execute( if "veo-3.0" in model: parameters["generateAudio"] = generate_audio - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - # Initial request to start video generation - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=f"/proxy/veo/{model}/generate", - method=HttpMethod.POST, - request_model=VeoGenVidRequest, - response_model=VeoGenVidResponse - ), - request=VeoGenVidRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + response_model=VeoGenVidResponse, + data=VeoGenVidRequest( instances=instances, - parameters=parameters + parameters=parameters, ), - auth_kwargs=auth, ) - initial_response = await initial_operation.execute() - operation_name = initial_response.name - - logging.info("Veo generation started with operation name: %s", operation_name) - - # Define status extractor function def status_extractor(response): # Only return "completed" if the operation is done, regardless of success or failure # We'll check for errors after polling completes return "completed" if response.done else "pending" - # Define progress extractor function - def progress_extractor(response): - # Could be enhanced if the API provides progress information - return None - - # Define the polling operation - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/veo/{model}/poll", - method=HttpMethod.POST, - request_model=VeoGenVidPollRequest, - response_model=VeoGenVidPollResponse - ), - completed_statuses=["completed"], - failed_statuses=[], # No failed statuses, we'll handle errors after polling + poll_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + response_model=VeoGenVidPollResponse, status_extractor=status_extractor, - progress_extractor=progress_extractor, - request=VeoGenVidPollRequest( - operationName=operation_name + data=VeoGenVidPollRequest( + operationName=initial_response.name, ), - auth_kwargs=auth, poll_interval=5.0, - result_url_extractor=get_video_url_from_response, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) - # Execute the polling operation - poll_response = await poll_operation.execute() - # Now check for errors in the final response # Check for error in poll response - if hasattr(poll_response, 'error') and poll_response.error: - error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})" - logging.error(error_message) - raise Exception(error_message) + if poll_response.error: + raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})") # Check for RAI filtered content - if (hasattr(poll_response.response, 'raiMediaFilteredCount') and - poll_response.response.raiMediaFilteredCount > 0): + if ( + hasattr(poll_response.response, "raiMediaFilteredCount") + and poll_response.response.raiMediaFilteredCount > 0 + ): # Extract reason message if available - if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and - poll_response.response.raiMediaFilteredReasons): + if ( + hasattr(poll_response.response, "raiMediaFilteredReasons") + and poll_response.response.raiMediaFilteredReasons + ): reason = poll_response.response.raiMediaFilteredReasons[0] error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)" else: error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)" - logging.error(error_message) raise Exception(error_message) # Extract video data - if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0: + if ( + poll_response.response + and hasattr(poll_response.response, "videos") + and poll_response.response.videos + and len(poll_response.response.videos) > 0 + ): video = poll_response.response.videos[0] # Check if video is provided as base64 or URL - if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded: - # Decode base64 string to bytes - video_data = base64.b64decode(video.bytesBase64Encoded) - elif hasattr(video, 'gcsUri') and video.gcsUri: - # Download from URL - async with aiohttp.ClientSession() as session: - async with session.get(video.gcsUri) as video_response: - video_data = await video_response.content.read() - else: - raise Exception("Video returned but no data or URL was provided") - else: - raise Exception("Video generation completed but no video was returned") - - if not video_data: - raise Exception("No video data was returned") + if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded: + return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) - logging.info("Video generation completed successfully") + if hasattr(video, "gcsUri") and video.gcsUri: + return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) - # Convert video data to BytesIO object - video_io = BytesIO(video_data) - - # Return VideoFromFile object - return IO.NodeOutput(VideoFromFile(video_io)) + raise Exception("Video returned but no data or URL was provided") + raise Exception("Video generation completed but no video was returned") class Veo3VideoGenerationNode(VeoVideoGenerationNode): @@ -391,7 +318,10 @@ def define_schema(cls): IO.Combo.Input( "model", options=[ - "veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.0-generate-001", "veo-3.0-fast-generate-001" + "veo-3.1-generate", + "veo-3.1-fast-generate", + "veo-3.0-generate-001", + "veo-3.0-fast-generate-001", ], default="veo-3.0-generate-001", tooltip="Veo 3 model to use for video generation", @@ -424,5 +354,6 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: Veo3VideoGenerationNode, ] + async def comfy_entrypoint() -> VeoExtension: return VeoExtension() diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 1c37d16868d3..5833b118fdf6 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -136,6 +136,7 @@ async def poll_op( completed_statuses: Optional[list[Union[str, int]]] = None, failed_statuses: Optional[list[Union[str, int]]] = None, queued_statuses: Optional[list[Union[str, int]]] = None, + data: Optional[BaseModel] = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, timeout_per_poll: float = 120.0, @@ -155,6 +156,7 @@ async def poll_op( completed_statuses=completed_statuses, failed_statuses=failed_statuses, queued_statuses=queued_statuses, + data=data, poll_interval=poll_interval, max_poll_attempts=max_poll_attempts, timeout_per_poll=timeout_per_poll, @@ -229,6 +231,7 @@ async def poll_op_raw( completed_statuses: Optional[list[Union[str, int]]] = None, failed_statuses: Optional[list[Union[str, int]]] = None, queued_statuses: Optional[list[Union[str, int]]] = None, + data: Optional[Union[dict[str, Any], BaseModel]] = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, timeout_per_poll: float = 120.0, @@ -289,6 +292,7 @@ async def _ticker(): resp_json = await sync_op_raw( cls, poll_endpoint, + data=data, timeout=timeout_per_poll, max_retries=max_retries_per_poll, retry_delay=retry_delay_per_poll, From b81cfed5a06a0cbb61942cffde54b2338845459d Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Thu, 23 Oct 2025 16:00:23 +0300 Subject: [PATCH 8/8] fix(Veo3.1 model): take into account "generate_audio" parameter --- comfy_api_nodes/nodes_veo2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 2b17ce5380b6..d37e9e9b410a 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -164,7 +164,7 @@ async def execute( if seed > 0: parameters["seed"] = seed # Only add generateAudio for Veo 3 models - if "veo-3.0" in model: + if model.find("veo-2.0") == -1: parameters["generateAudio"] = generate_audio initial_response = await sync_op(