86 changes: 86 additions & 0 deletions src/tetra_rp/core/api/runpod.py
@@ -110,6 +110,7 @@ async def create_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
instanceIds
activeBuildid
idePodId
templateId
}
}
"""
@@ -132,6 +133,91 @@ async def create_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]:

return endpoint_data

async def update_endpoint(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
"""Update an existing endpoint via the saveEndpoint mutation."""
mutation = """
mutation saveEndpoint($input: EndpointInput!) {
saveEndpoint(input: $input) {
aiKey
gpuIds
id
idleTimeout
locations
name
networkVolumeId
scalerType
scalerValue
templateId
type
userId
version
workersMax
workersMin
workersStandby
workersPFBTarget
gpuCount
allowedCudaVersions
executionTimeoutMs
instanceIds
activeBuildid
idePodId
}
}
"""

variables = {"input": input_data}

log.debug(
f"Updating endpoint with GraphQL: {input_data.get('name', 'unnamed')}"
)

result = await self._execute_graphql(mutation, variables)

if "saveEndpoint" not in result:
raise Exception("Unexpected GraphQL response structure")

endpoint_data = result["saveEndpoint"]
log.info(
f"Updated endpoint: {endpoint_data.get('id', 'unknown')} - {endpoint_data.get('name', 'unnamed')}"
)

return endpoint_data

async def update_template(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
"""Update an existing pod template via the saveTemplate mutation."""
mutation = """
mutation saveTemplate($input: SaveTemplateInput) {
saveTemplate(input: $input) {
id
containerDiskInGb
dockerArgs
env {
key
value
}
imageName
name
readme
}
}
"""

variables = {"input": input_data}

log.debug(
f"Updating template with GraphQL: {input_data.get('name', 'unnamed')}"
)

result = await self._execute_graphql(mutation, variables)

if "saveTemplate" not in result:
raise Exception("Unexpected GraphQL response structure")

template_data = result["saveTemplate"]
log.info(
f"Updated template: {template_data.get('id', 'unknown')} - {template_data.get('name', 'unnamed')}"
)

return template_data

async def get_cpu_types(self) -> Dict[str, Any]:
"""Get available CPU types."""
query = """
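For context, a minimal sketch of how these new mutations are exercised; the client usage mirrors serverless.py below, and the import path follows this file's location, but the payload values are illustrative assumptions, not real ids:

import asyncio

from tetra_rp.core.api.runpod import RunpodGraphQLClient

async def main():
    # Hypothetical input; keys mirror the EndpointInput fields selected above.
    # "id" must reference an existing endpoint so saveEndpoint updates rather than creates.
    endpoint_input = {
        "id": "ep-123",
        "name": "demo-endpoint",
        "workersMax": 3,
    }
    async with RunpodGraphQLClient() as client:
        endpoint = await client.update_endpoint(endpoint_input)
        print(endpoint.get("id"), endpoint.get("templateId"))

asyncio.run(main())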
34 changes: 29 additions & 5 deletions src/tetra_rp/core/resources/base.py
@@ -1,8 +1,7 @@
import hashlib
from abc import ABC, abstractmethod
from typing import Optional
from pydantic import BaseModel, ConfigDict

from typing import Optional, ClassVar
from pydantic import BaseModel, ConfigDict, computed_field

class BaseResource(BaseModel):
"""Base class for all resources."""
@@ -14,15 +13,35 @@ class BaseResource(BaseModel):
)

id: Optional[str] = None
_hashed_fields: ClassVar[set] = set()

# fields_to_update is a temporary holder for fields that are "out of sync",
# i.e. where the local instance representation of an endpoint is not up to date with the remote resource.
# it's needed for determining how updates are applied (e.g. whether a pod template must also be updated)
fields_to_update: set[str] = set()


@computed_field
@property
def resource_id(self) -> str:
def resource_hash(self) -> str:
"""Unique resource ID based on configuration."""
resource_type = self.__class__.__name__
config_str = self.model_dump_json(exclude_none=True)
# exclude self-references and any deployment state (e.g. id) from the hash
config_str = self.model_dump_json(include=self.__class__._hashed_fields)
hash_obj = hashlib.md5(f"{resource_type}:{config_str}".encode())
return f"{resource_type}_{hash_obj.hexdigest()}"

@property
def resource_id(self) -> str:
"""Logical Tetra resource id defined by resource type and name.
Distinct from a server-side Runpod id.
"""
resource_type = self.__class__.__name__
# TODO: eventually we could namespace this to user ids or team ids
if not self.name:
self.name = "unnamed"
return f"{resource_type}_{self.name}"


class DeployableResource(BaseResource, ABC):
"""Base class for deployable resources."""
@@ -45,3 +64,8 @@ def is_deployed(self) -> bool:
async def deploy(self) -> "DeployableResource":
"""Deploy the resource."""
raise NotImplementedError("Subclasses should implement this method.")

@abstractmethod
async def update(self) -> "DeployableResource":
"""Update the resource."""
raise NotImplementedError("Subclasses should implement this method.")
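A minimal sketch of the intended contract, assuming a hypothetical subclass (DemoResource is not part of this PR): only fields named in _hashed_fields feed resource_hash, so server-side state such as id never perturbs it.

from typing import ClassVar, Optional

from tetra_rp.core.resources.base import BaseResource

class DemoResource(BaseResource):
    _hashed_fields: ClassVar[set] = {"name", "workersMax"}
    name: Optional[str] = None
    workersMax: int = 3

a = DemoResource(name="demo", id="srv-1")
b = DemoResource(name="demo", id="srv-2")
assert a.resource_hash == b.resource_hash   # id is excluded from the hash
assert a.resource_id == "DemoResource_demo" # type + name, independent of Runpod ids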
4 changes: 4 additions & 0 deletions src/tetra_rp/core/resources/network_volume.py
@@ -124,6 +124,10 @@ async def _create_new_volume(self, client) -> "NetworkVolume":

raise ValueError("Deployment failed, no volume was created.")

async def update(self) -> "DeployableResource":
# TODO: implement update support for network volumes
return self

async def deploy(self) -> "DeployableResource":
"""
Deploys the network volume resource using the provided configuration.
20 changes: 20 additions & 0 deletions src/tetra_rp/core/resources/resource_manager.py
@@ -100,6 +100,7 @@ async def get_or_deploy_resource(
async with resource_lock:
# Double-check pattern: check again inside the lock
if existing := self._resources.get(uid):
# if the old resource isn't actually deployed, then we can just deploy the new one
if not existing.is_deployed():
log.warning(f"{existing} is no longer valid, redeploying.")
self.remove_resource(uid)
@@ -109,6 +110,25 @@
self.add_resource(uid, deployed_resource)
return deployed_resource

# if the old resource is actually deployed, then we need to update it
if existing.resource_hash != config.resource_hash:
log.info("Change in resource configuration detected, updating resource.")
for field in existing.__class__._hashed_fields:
existing_value, new_value = getattr(existing, field), getattr(config, field)
if existing_value != new_value:
log.debug(f"field: {field}, existing value: {existing_value}, new value: {new_value}")
config.fields_to_update.add(field)

# some fields persisted in pickled state must be loaded back onto the new object;
# they are used to apply updates to platform endpoints/resources
# TODO: clean this up
await config.sync_config_with_deployed_resource(existing)
deployed_resource = await config.update()
self.remove_resource(uid)
self.add_resource(uid, deployed_resource)
return deployed_resource

# otherwise, nothing has changed and we just return what we already have
log.debug(f"{existing} exists, reusing.")
log.info(f"URL: {existing.url}")
return existing
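The diffing loop above reduces to a pure set computation; a standalone sketch (diff_hashed_fields is a hypothetical helper, not part of this PR):

def diff_hashed_fields(existing, config) -> set:
    """Return the hashed fields whose values differ between two resources."""
    return {
        field
        for field in existing.__class__._hashed_fields
        if getattr(existing, field) != getattr(config, field)
    }

# Usage mirrors the lock-protected block above:
# config.fields_to_update |= diff_hashed_fields(existing, config)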
118 changes: 111 additions & 7 deletions src/tetra_rp/core/resources/serverless.py
@@ -61,6 +61,7 @@ class ServerlessResource(DeployableResource):
Base class for GPU serverless resource
"""

# Fields listed in _input_only are excluded from gql requests to keep the client implementation simple
_input_only = {
"id",
"cudaVersions",
@@ -70,6 +71,29 @@ class ServerlessResource(DeployableResource):
"flashboot",
"imageName",
"networkVolume",
"resource_hash",
"fields_to_update",
}

# Hashed fields define the configuration of an object. They are used to compute
# whether a resource has changed, and should only include fields that are mutable
# from the platform's perspective. Platform (Runpod) state fields (e.g. endpoint id)
# are not accounted for right now.
_hashed_fields = {
"datacenter",
"env",
"gpuIds",
"networkVolume",
"executionTimeoutMs",
"gpuCount",
"locations",
"name",
"networkVolumeId",
"scalerType",
"scalerValue",
"workersMax",
"workersMin",
"workersPFBTarget",
"allowedCudaVersions",
}

# === Input-only Fields ===
@@ -82,7 +106,7 @@ class ServerlessResource(DeployableResource):
datacenter: DataCenter = Field(default=DataCenter.EU_RO_1)

# === Input Fields ===
executionTimeoutMs: Optional[int] = None
executionTimeoutMs: Optional[int] = 0
gpuCount: Optional[int] = 1
idleTimeout: Optional[int] = 5
locations: Optional[str] = None
@@ -93,12 +117,12 @@ class ServerlessResource(DeployableResource):
templateId: Optional[str] = None
workersMax: Optional[int] = 3
workersMin: Optional[int] = 0
workersPFBTarget: Optional[int] = None
workersPFBTarget: Optional[int] = 0

# === Runtime Fields ===
activeBuildid: Optional[str] = None
aiKey: Optional[str] = None
allowedCudaVersions: Optional[str] = None
allowedCudaVersions: str = ""
computeType: Optional[str] = None
createdAt: Optional[str] = None # TODO: use datetime
gpuIds: Optional[str] = ""
@@ -143,7 +167,7 @@ def validate_gpus(cls, value: List[GpuGroup]) -> List[GpuGroup]:
@model_validator(mode="after")
def sync_input_fields(self):
"""Sync between temporary inputs and exported fields"""
if self.flashboot:
if self.flashboot and not self.name.endswith("-fb"):
self.name += "-fb"

# Sync datacenter to locations field for API
@@ -167,7 +191,10 @@ def sync_input_fields(self):

def _sync_input_fields_gpu(self):
# GPU-specific fields
if self.gpus:
# the API response sets gpus to None; take this path only when gpuIds is
# unset, otherwise we would overwrite gpuIds with ANY (the default for
# gpus is any GPU)
if self.gpus and not self.gpuIds:
# Convert gpus list to gpuIds string
self.gpuIds = ",".join(gpu.value for gpu in self.gpus)
elif self.gpuIds:
@@ -199,6 +226,43 @@ async def _ensure_network_volume_deployed(self) -> None:
deployedNetworkVolume = await self.networkVolume.deploy()
self.networkVolumeId = deployedNetworkVolume.id

async def _sync_graphql_object_with_inputs(self, returned_endpoint: "ServerlessResource"):
"""Copy input-only fields (stripped from the gql request) back onto the returned endpoint."""
for _input_field in self._input_only:
if _input_field != "resource_hash" and getattr(self, _input_field) is not None:
# sync input-only fields stripped from the gql request back to the endpoint
setattr(returned_endpoint, _input_field, getattr(self, _input_field))

# assigning template info back to the object is needed for updating it in the future
returned_endpoint.template = self.template
if returned_endpoint.template:
returned_endpoint.template.id = returned_endpoint.templateId

return returned_endpoint

async def sync_config_with_deployed_resource(self, existing: "ServerlessResource") -> None:
self.id = existing.id
if not existing.template:
raise ValueError("Existing resource does not have a template; this is an invalid state. Update resources and try again.")
self.template.id = existing.template.id

async def _update_template(self) -> "DeployableResource":
if not self.template:
raise ValueError("Tried to update a template that doesn't exist. Redeploy endpoint or attach a template to it")

try:
async with RunpodGraphQLClient() as client:
payload = self.template.model_dump(exclude={"resource_hash", "fields_to_update"}, exclude_none=True)
result = await client.update_template(payload)
if template := self.template.__class__(**result):
return template

raise ValueError("Update failed, no template was returned.")

except Exception as e:
log.error(f"{self} failed to update: {e}")
raise


def is_deployed(self) -> bool:
"""
Checks if the serverless resource is deployed and available.
@@ -228,18 +292,58 @@ async def deploy(self) -> "DeployableResource":
await self._ensure_network_volume_deployed()

async with RunpodGraphQLClient() as client:
payload = self.model_dump(exclude=self._input_only, exclude_none=True)
# some "input only" fields are specific to tetra and not used in gql
exclude = {f: ... for f in self._input_only} | {
"template": {"resource_hash", "fields_to_update", "volumeInGb"}
}  # TODO: maybe include this as a class attr
payload = self.model_dump(exclude=exclude, exclude_none=True)
result = await client.create_endpoint(payload)

# merge the fields returned from gql with the inputs we hold locally
if endpoint := self.__class__(**result):
return endpoint
endpoint = await self._sync_graphql_object_with_inputs(endpoint)
return endpoint

raise ValueError("Deployment failed, no endpoint was returned.")

except Exception as e:
log.error(f"{self} failed to deploy: {e}")
raise

async def update(self) -> "DeployableResource":
"""Update the deployed endpoint, updating its template first when needed."""
# check if we need to update the template
# only update if the template exists already and there are fields to update for it
if self.template and self.fields_to_update & set(self.template.model_fields):
# we need to add the template id back here from hydrated state
log.debug(f"loaded template to update: {self.template.model_dump()}")
template = await self._update_template()
self.template = template

# if the only fields that need updating are template-only, just return now
if self.fields_to_update <= set(template.model_fields):
log.debug("template-only update to endpoint complete")
return self

try:
async with RunpodGraphQLClient() as client:
exclude = {f: ... for f in self._input_only} | {
"template": {"resource_hash", "fields_to_update", "volumeInGb", "id"}
}  # TODO: maybe include this as a class attr
# we need to include the id here so we update the existing endpoint
del exclude["id"]
payload = self.model_dump(exclude=exclude, exclude_none=True)
result = await client.update_endpoint(payload)

if endpoint := self.__class__(**result):
# TODO: should we check that the returned id = the input?
# we could "soft fail" and notify the user if we fall back to making a new endpoint
endpoint = await self._sync_graphql_object_with_inputs(endpoint)
return endpoint

raise ValueError("Update failed, no endpoint was returned.")

except Exception as e:
log.error(f"{self} failed to update: {e}")
raise

async def run_sync(self, payload: Dict[str, Any]) -> "JobOutput":
"""
Executes a serverless endpoint request with the payload.
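Two of the subtler pieces of update() in miniature; the field sets below are hypothetical stand-ins for _input_only and PodTemplate.model_fields, not the real values:

# 1. The template-only short-circuit is a subset test.
template_fields = {"env", "imageName", "containerDiskInGb"}
assert {"env"} <= template_fields                      # template-only change: return early
assert not ({"env", "workersMax"} <= template_fields)  # endpoint mutation still required

# 2. The exclude map strips input-only fields at the top level and bookkeeping
#    fields from the nested template, using pydantic's nested-exclude syntax.
input_only = {"flashboot", "gpus", "networkVolume"}
exclude = {f: ... for f in input_only} | {
    "template": {"resource_hash", "fields_to_update", "volumeInGb"}
}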
1 change: 1 addition & 0 deletions src/tetra_rp/core/resources/template.py
@@ -30,6 +30,7 @@ class PodTemplate(BaseResource):
name: Optional[str] = ""
ports: Optional[str] = ""
startScript: Optional[str] = ""
volumeInGb: Optional[int] = 20

@model_validator(mode="after")
def sync_input_fields(self):