Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 7 additions & 66 deletions src/tetra_rp/core/resources/serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,12 @@ def serialize_scaler_type(
return value.value if value is not None else None

@field_serializer("instanceIds")
def serialize_instance_ids(self, value: List[CpuInstanceType]) -> List[str]:
def serialize_instance_ids(
self, value: Optional[List[CpuInstanceType]]
) -> Optional[List[str]]:
"""Convert CpuInstanceType enums to strings."""
if value is None:
return None
return [item.value if hasattr(item, "value") else str(item) for item in value]

@field_validator("gpus")
Expand Down Expand Up @@ -247,62 +251,6 @@ async def deploy(self) -> "DeployableResource":
log.error(f"{self} failed to deploy: {e}")
raise

async def is_ready_for_requests(self, give_up_threshold=10) -> bool:
"""
Asynchronously checks if the serverless resource is ready to handle
requests by polling its health endpoint.

Args:
give_up_threshold (int, optional): The maximum number of polling
attempts before giving up and raising an error. Defaults to 10.

Returns:
bool: True if the serverless resource is ready for requests.

Raises:
ValueError: If the serverless resource is not deployed.
RuntimeError: If the health status is THROTTLED, UNHEALTHY, or UNKNOWN
after exceeding the give_up_threshold.
"""
if not self.is_deployed():
raise ValueError("Serverless is not deployed")

log.debug(f"{self} | API /health")

current_pace = 0
attempt = 0

# Poll for health status
while True:
await asyncio.sleep(current_pace)

health = await asyncio.to_thread(self.endpoint.health)
health = ServerlessHealth(**health)

if health.is_ready:
return True
else:
# nothing changed, increase the gap
attempt += 1
indicator = "." * (attempt // 2) if attempt % 2 == 0 else ""
if indicator:
log.info(f"{self} | {indicator}")

status = health.workers.status
if status in [
Status.THROTTLED,
Status.UNHEALTHY,
Status.UNKNOWN,
]:
log.debug(f"{self} | Health {status.value}")

if attempt >= give_up_threshold:
# Give up
raise RuntimeError(f"Health {status.value}")

# Adjust polling pace appropriately
current_pace = get_backoff_delay(attempt)

async def run_sync(self, payload: Dict[str, Any]) -> "JobOutput":
"""
Executes a serverless endpoint request with the payload.
Expand All @@ -319,9 +267,6 @@ def _fetch_job():
try:
# log.debug(f"[{log_group}] Payload: {payload}")

# Poll until requests can be sent
await self.is_ready_for_requests()

log.info(f"{self} | API /run_sync")
response = await asyncio.to_thread(_fetch_job)
return JobOutput(**response)
Expand All @@ -346,9 +291,6 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput":
try:
# log.debug(f"[{self}] Payload: {payload}")

# Poll until requests can be sent
await self.is_ready_for_requests()

# Create a job using the endpoint
log.info(f"{self} | API /run")
job = await asyncio.to_thread(self.endpoint.run, request_input=payload)
Expand All @@ -366,9 +308,8 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput":
while True:
await asyncio.sleep(current_pace)

if await self.is_ready_for_requests():
# Check job status
job_status = await asyncio.to_thread(job.status)
# Check job status
job_status = await asyncio.to_thread(job.status)

if last_status == job_status:
# nothing changed, increase the gap
Expand Down
Loading
Loading