Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/tetra_rp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ServerlessEndpoint,
runpod,
NetworkVolume,
FlashProject,
)


Expand All @@ -40,4 +41,5 @@
"ServerlessEndpoint",
"runpod",
"NetworkVolume",
"FlashProject",
]
210 changes: 209 additions & 1 deletion src/tetra_rp/core/api/runpod.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
import logging
import os
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, List

import aiohttp

Expand Down Expand Up @@ -202,6 +202,213 @@ async def delete_endpoint(self, endpoint_id: str) -> Dict[str, Any]:
result = await self._execute_graphql(mutation, variables)
return {"success": result.get("deleteEndpoint") is not None}

async def list_flash_projects(self) -> Dict[str, Any]:
"""
List all flash projects in Runpod.
"""
log.debug("Listing Flash projects")
query = """
query getFlashProjects {
myself {
flashProjects {
id
name
environments {
id
name
}
}
}
}
"""

result = await self._execute_graphql(query)
return result["myself"].get("flashProjects", [])

async def prepare_artifact_upload(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
mutation = """
mutation PrepareArtifactUpload($input: PrepareFlashArtifactUploadInput!) {
prepareFlashArtifactUpload(input: $input) {
uploadUrl
objectKey
expiresAt
}
}
"""
variables = {"input": input_data}

log.debug(f"Preparing upload url for flash environment: {input_data}")

result = await self._execute_graphql(mutation, variables)
return result

async def finalize_artifact_upload(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
mutation = """
mutation FinalizeArtifactUpload($input: FinalizeFlashArtifactUploadInput!) {
finalizeFlashArtifactUpload(input: $input) {
id
versionNumber
status
}
}
"""
variables = {"input": input_data}

log.debug(f"finalizing upload for flash project: {input_data}")

result = await self._execute_graphql(mutation, variables)
return result



async def get_flash_project(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
query = """
query getFlashProject($input: String!) {
flashProject(projectId: $input) {
id
name
environments {
id
name
}
}
}
"""
variables = {"input": input_data}

log.debug(f"Fetching flash project for input: {input_data}")
result = await self._execute_graphql(query, variables)
return result

async def get_flash_project_by_name(self, project_name: str) -> Dict[str, Any]:
query = """
query getFlashProjectByName($projectName: String!) {
flashProjectByName(projectName: $projectName) {
id
name
environments {
id
name
}
}
}
"""
variables = {"projectName": project_name}

log.debug(f"Fetching flash project by name for input: {project_name}")
result = await self._execute_graphql(query, variables)
return result

async def get_flash_environment(self, environment_id: str, requested_vars: Optional[List[str]] = None) -> Dict[str, Any]:
if not requested_vars:
requested_vars = ["id", "name"]
fragment = "\n".join(requested_vars)
query = f"""
query getFlashEnvironment($environmentId: String!) {{
flashEnvironment(environmentId: $environmentId) {{
{fragment}
}}
}}
"""
variables = {"environmentId": environment_id}

log.debug(f"Fetching flash project by name for input: {variables}")
result = await self._execute_graphql(query, variables)
return result

async def get_flash_environment_by_name(self, project_id: str, environment_name: str) -> Dict[str, Any]:
query = """
query getFlashEnvironmentByName($environmentName: String!) {
flashEnvironmentByName(environmentName: $environmentName) {
id
name
}
}
"""
variables = {"flashProjectId": project_id, "name": environment_name}

log.debug(f"Fetching flash project by name for input: {variables}")
result = await self._execute_graphql(query, variables)
return result

async def get_flash_artifact_url(self, environment_id: str) -> Dict[str, Any]:
result = await self.get_flash_environment(environment_id, ["name", "activeArtifact { objectKey\ndownloadUrl }"])
return result

async def deploy_build_to_environment(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
# TODO(jhcipar) should we not generate a presigned url when promoting a build here?
mutation = """
mutation deployBuildToEnvironment($input: DeployBuildToEnvironmentInput!) {
deployBuildToEnvironment(input: $input) {
id
name
activeArtifact {
objectKey
downloadUrl
expiresAt
}
}
}
"""

variables = {"input": input_data}

log.debug(
f"Deploying flash environment with vars: {input_data}"
)

result = await self._execute_graphql(mutation, variables)
return result

async def create_flash_project(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
"""Create a new flash project in Runpod.
"""
log.debug(f"creating flash project with name {input_data.get('name')}")

mutation = """
mutation createFlashProject($input: CreateFlashProjectInput!) {
createFlashProject(input: $input) {
id
name
}
}
"""

variables = {"input": input_data}

log.debug(
f"Creating flash project with GraphQL: {input_data.get('name', 'unnamed')}"
)

result = await self._execute_graphql(mutation, variables)

return result

async def create_flash_environment(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
"""Create an environment within a flash project.
"""
log.debug(f"creating flash environment with name {input_data.get('name')}")

mutation = """
mutation createFlashEnvironment($input: CreateFlashEnvironmentInput!) {
createFlashEnvironment(input: $input) {
id
name
}
}
"""

variables = {"input": input_data}

log.debug(
f"Creating flash environment with GraphQL: {input_data.get('name', 'unnamed')}"
)

result = await self._execute_graphql(mutation, variables)

return result


async def close(self):
"""Close the HTTP session."""
if self.session and not self.session.closed:
Expand Down Expand Up @@ -305,6 +512,7 @@ async def list_network_volumes(self) -> Dict[str, Any]:

return result


async def close(self):
"""Close the HTTP session."""
if self.session and not self.session.closed:
Expand Down
2 changes: 2 additions & 0 deletions src/tetra_rp/core/resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .serverless_cpu import CpuServerlessEndpoint
from .template import PodTemplate
from .network_volume import NetworkVolume, DataCenter
from .project import FlashProject


__all__ = [
Expand All @@ -34,4 +35,5 @@
"ServerlessEndpoint",
"PodTemplate",
"NetworkVolume",
"FlashProject",
]
121 changes: 121 additions & 0 deletions src/tetra_rp/core/resources/project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import pathlib
import requests
import asyncio
from typing import Dict, Callable, TYPE_CHECKING

from ..api.runpod import RunpodGraphQLClient
from ..resources.resource_manager import ResourceManager

if TYPE_CHECKING:
from . import ServerlessResource

class FlashProject:
def __init__(self, name: str):
self.name: str = name
self.id: str = ""
self.resources: Dict[str, "ServerlessResource"] = {}
with asyncio.Runner() as runner:
runner.run(self._get_or_create_self())

def remote(self, *args, **kwargs):
from tetra_rp.client import remote as remote_decorator

resource_config = kwargs.get("resource_config")

if resource_config is None and args:
candidate = args[0]
if hasattr(candidate, "resource_id"):
self.resources[candidate.resource_id] = candidate

return remote_decorator(*args, **kwargs)

async def _get_or_create_self(self):
async with RunpodGraphQLClient() as client:
try:
result = await client.get_flash_project_by_name(self.name)
self.id = result["flashProjectByName"]["id"]
return result
except Exception as exc:
if not "project not found" in str(exc).lower():
raise
result = await client.create_flash_project({"name": self.name})

self.id = result["createFlashProject"]["id"]
return result

async def _get_id_by_name(self):
async with RunpodGraphQLClient() as client:
result = await client.get_flash_project_by_name(self.name)
if not result.get("flashProjectByName"):
raise ValueError("flash project not found", self.name)
return result["flashProjectByName"]["id"]

async def create_environment(self, environment_name: str):
async with RunpodGraphQLClient() as client:
result = await client.create_flash_environment({"flashProjectId": self.id, "name": environment_name})
return result["createFlashEnvironment"]


@staticmethod
async def list():
async with RunpodGraphQLClient() as client:
return await client.list_flash_projects()

async def _get_tarball_upload_url(self):
async with RunpodGraphQLClient() as client:
return await client.prepare_artifact_upload({"projectId": self.id})

async def _get_active_artifact(self, environment_id: str):
async with RunpodGraphQLClient() as client:
result = await client.get_flash_artifact_url(environment_id)
if not result["flashEnvironment"].get("activeArtifact"):
raise ValueError("No active artifact for environment id found", environment_id)
return result["flashEnvironment"]["activeArtifact"]

async def deploy_build_to_environment(self, environment_id: str, build_id: str):
async with RunpodGraphQLClient() as client:
result = await client.deploy_build_to_environment({"environmentId": environment_id, "flashBuildId": build_id})
return result

async def download_tarball(self, environment_id: str, dest_file: str):
result = await self._get_active_artifact(environment_id)
url = result["downloadUrl"]
with open(dest_file, "wb") as stream:
with requests.get(url, stream=True) as resp:
resp.raise_for_status()
for chunk in resp.iter_content():
if chunk:
stream.write(chunk)

async def _finalize_tarball_upload(self, object_key: str):
async with RunpodGraphQLClient() as client:
result = await client.finalize_artifact_upload(
{"projectId": self.id, "objectKey": object_key}
)
return result["finalizeFlashArtifactUpload"]

async def upload_tarball(self, tar_path: str):
result = await self._get_tarball_upload_url()
url = result["prepareFlashArtifactUpload"]["uploadUrl"]
object_key = result["prepareFlashArtifactUpload"]["objectKey"]

path = pathlib.Path(tar_path)
headers = {"Content-Type": "application/x-tar"}

with path.open("rb") as fh:
resp = requests.put(url, data=fh, headers=headers)

resp.raise_for_status()
resp = await self._finalize_tarball_upload(object_key)
return resp

async def _deploy_in_environment(self, environment: str):
"""
Entrypoint for cpu sls endpoint to execute provisioning its registered resources.
Goes through all registered resources and gets or deploys them
Should update app env state as Ready at the end
TODO(jhcipar) should add flash env into resource identifiers
"""
resource_manager = ResourceManager()
for resource_id, resource in self.resources.items():
await resource_manager.get_or_deploy_resource(resource)
Loading