diff --git a/src/tetra_rp/__init__.py b/src/tetra_rp/__init__.py
index 0849a71..dce0129 100644
--- a/src/tetra_rp/__init__.py
+++ b/src/tetra_rp/__init__.py
@@ -23,6 +23,7 @@
     ServerlessEndpoint,
     runpod,
     NetworkVolume,
+    FlashProject,
 )
 
 
@@ -40,4 +41,5 @@
     "ServerlessEndpoint",
     "runpod",
     "NetworkVolume",
+    "FlashProject",
 ]
diff --git a/src/tetra_rp/core/api/runpod.py b/src/tetra_rp/core/api/runpod.py
index af1754d..778bf7f 100644
--- a/src/tetra_rp/core/api/runpod.py
+++ b/src/tetra_rp/core/api/runpod.py
@@ -6,7 +6,7 @@
 import json
 import logging
 import os
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 
 import aiohttp
 
@@ -202,6 +202,328 @@ async def delete_endpoint(self, endpoint_id: str) -> Dict[str, Any]:
         result = await self._execute_graphql(mutation, variables)
         return {"success": result.get("deleteEndpoint") is not None}
 
+    async def list_flash_projects(self) -> List[Dict[str, Any]]:
+        """
+        List all flash projects in Runpod.
+
+        Returns:
+            The ``myself.flashProjects`` entries of the response, or an
+            empty list when that field is missing or null.
+        """
+        log.debug("Listing Flash projects")
+        query = """
+        query getFlashProjects {
+            myself {
+                flashProjects {
+                    id
+                    name
+                    environments {
+                        id
+                        name
+                    }
+                }
+            }
+        }
+        """
+
+        result = await self._execute_graphql(query)
+        return result["myself"].get("flashProjects") or []
+
+    async def prepare_artifact_upload(
+        self, input_data: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """
+        Request a URL to upload a flash build artifact to.
+
+        Args:
+            input_data: Sent as the ``PrepareFlashArtifactUploadInput``
+                mutation input.
+
+        Returns:
+            The GraphQL result; ``uploadUrl``, ``objectKey`` and ``expiresAt``
+            are requested under the ``prepareFlashArtifactUpload`` key.
+        """
+        mutation = """
+        mutation PrepareArtifactUpload($input: PrepareFlashArtifactUploadInput!) {
+            prepareFlashArtifactUpload(input: $input) {
+                uploadUrl
+                objectKey
+                expiresAt
+            }
+        }
+        """
+        variables = {"input": input_data}
+
+        log.debug(f"Preparing upload url for flash environment: {input_data}")
+
+        result = await self._execute_graphql(mutation, variables)
+        return result
+
+    async def finalize_artifact_upload(
+        self, input_data: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """
+        Run the ``finalizeFlashArtifactUpload`` mutation for an artifact.
+
+        Args:
+            input_data: Sent as the ``FinalizeFlashArtifactUploadInput``
+                mutation input.
+
+        Returns:
+            The GraphQL result; ``id``, ``versionNumber`` and ``status`` are
+            requested under the ``finalizeFlashArtifactUpload`` key.
+        """
+        mutation = """
+        mutation FinalizeArtifactUpload($input: FinalizeFlashArtifactUploadInput!) {
+            finalizeFlashArtifactUpload(input: $input) {
+                id
+                versionNumber
+                status
+            }
+        }
+        """
+        variables = {"input": input_data}
+
+        log.debug(f"finalizing upload for flash project: {input_data}")
+
+        result = await self._execute_graphql(mutation, variables)
+        return result
+
+    async def get_flash_project(self, input_data: str) -> Dict[str, Any]:
+        """
+        Fetch a flash project and its environments by project id.
+
+        Args:
+            input_data: The project id; it is bound to the ``String!``
+                variable passed as ``flashProject(projectId: ...)``.
+
+        Returns:
+            The GraphQL result, with the project under ``flashProject``.
+        """
+        query = """
+        query getFlashProject($input: String!) {
+            flashProject(projectId: $input) {
+                id
+                name
+                environments {
+                    id
+                    name
+                }
+            }
+        }
+        """
+        variables = {"input": input_data}
+
+        log.debug(f"Fetching flash project for input: {input_data}")
+        result = await self._execute_graphql(query, variables)
+        return result
+
+    async def get_flash_project_by_name(self, project_name: str) -> Dict[str, Any]:
+        """
+        Fetch a flash project and its environments by project name.
+
+        Returns:
+            The GraphQL result, with the project under ``flashProjectByName``.
+        """
+        query = """
+        query getFlashProjectByName($projectName: String!) {
+            flashProjectByName(projectName: $projectName) {
+                id
+                name
+                environments {
+                    id
+                    name
+                }
+            }
+        }
+        """
+        variables = {"projectName": project_name}
+
+        log.debug(f"Fetching flash project by name for input: {project_name}")
+        result = await self._execute_graphql(query, variables)
+        return result
+
+    async def get_flash_environment(
+        self, environment_id: str, requested_vars: Optional[List[str]] = None
+    ) -> Dict[str, Any]:
+        """
+        Fetch selected fields of a flash environment by id.
+
+        Args:
+            environment_id: Id of the environment to query.
+            requested_vars: GraphQL selections to request; defaults to
+                ``["id", "name"]`` when empty or omitted. The entries are
+                interpolated verbatim into the query, so pass trusted
+                strings only.
+
+        Returns:
+            The GraphQL result, with the fields under ``flashEnvironment``.
+        """
+        if not requested_vars:
+            requested_vars = ["id", "name"]
+        fragment = "\n".join(requested_vars)
+        query = f"""
+        query getFlashEnvironment($environmentId: String!) {{
+            flashEnvironment(environmentId: $environmentId) {{
+                {fragment}
+            }}
+        }}
+        """
+        variables = {"environmentId": environment_id}
+
+        log.debug(f"Fetching flash environment for input: {variables}")
+        result = await self._execute_graphql(query, variables)
+        return result
+
+    async def get_flash_environment_by_name(
+        self, project_id: str, environment_name: str
+    ) -> Dict[str, Any]:
+        """
+        Fetch a flash environment's id and name by environment name.
+
+        Args:
+            project_id: Id of the owning project. The query below declares
+                no project variable, so it is only logged for now.
+                TODO: confirm whether ``flashEnvironmentByName`` can be
+                scoped to a project and pass it through if so.
+            environment_name: Name of the environment to look up.
+
+        Returns:
+            The GraphQL result, with the environment under
+            ``flashEnvironmentByName``.
+        """
+        query = """
+        query getFlashEnvironmentByName($environmentName: String!) {
+            flashEnvironmentByName(environmentName: $environmentName) {
+                id
+                name
+            }
+        }
+        """
+        # Keys must match the variables the query declares; a missing
+        # ``$environmentName`` makes the server reject the request.
+        variables = {"environmentName": environment_name}
+
+        log.debug(
+            f"Fetching flash environment {environment_name!r} "
+            f"(project {project_id!r}) by name"
+        )
+        result = await self._execute_graphql(query, variables)
+        return result
+
+    async def get_flash_artifact_url(self, environment_id: str) -> Dict[str, Any]:
+        """
+        Fetch an environment's name and active artifact.
+
+        Returns:
+            The GraphQL result; ``objectKey`` and ``downloadUrl`` are
+            requested under ``flashEnvironment.activeArtifact``.
+        """
+        result = await self.get_flash_environment(
+            environment_id, ["name", "activeArtifact { objectKey\ndownloadUrl }"]
+        )
+        return result
+
+    async def deploy_build_to_environment(
+        self, input_data: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """
+        Deploy a flash build to an environment.
+
+        Args:
+            input_data: Sent as the ``DeployBuildToEnvironmentInput``
+                mutation input.
+
+        Returns:
+            The GraphQL result, with the environment and its active artifact
+            under ``deployBuildToEnvironment``.
+        """
+        # TODO(jhcipar) should we not generate a presigned url when promoting a build here?
+        mutation = """
+        mutation deployBuildToEnvironment($input: DeployBuildToEnvironmentInput!) {
+            deployBuildToEnvironment(input: $input) {
+                id
+                name
+                activeArtifact {
+                    objectKey
+                    downloadUrl
+                    expiresAt
+                }
+            }
+        }
+        """
+
+        variables = {"input": input_data}
+
+        log.debug(f"Deploying flash environment with vars: {input_data}")
+
+        result = await self._execute_graphql(mutation, variables)
+        return result
+
+    async def create_flash_project(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Create a new flash project in Runpod.
+
+        Args:
+            input_data: Sent as the ``CreateFlashProjectInput`` mutation
+                input; its ``name`` entry is used for logging.
+
+        Returns:
+            The GraphQL result, with the new project's ``id`` and ``name``
+            under ``createFlashProject``.
+        """
+        mutation = """
+        mutation createFlashProject($input: CreateFlashProjectInput!) {
+            createFlashProject(input: $input) {
+                id
+                name
+            }
+        }
+        """
+
+        variables = {"input": input_data}
+
+        log.debug(
+            f"Creating flash project with GraphQL: {input_data.get('name', 'unnamed')}"
+        )
+
+        result = await self._execute_graphql(mutation, variables)
+
+        return result
+
+    async def create_flash_environment(
+        self, input_data: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """
+        Create an environment within a flash project.
+
+        Args:
+            input_data: Sent as the ``CreateFlashEnvironmentInput`` mutation
+                input; its ``name`` entry is used for logging.
+
+        Returns:
+            The GraphQL result, with the new environment's ``id`` and
+            ``name`` under ``createFlashEnvironment``.
+        """
+        mutation = """
+        mutation createFlashEnvironment($input: CreateFlashEnvironmentInput!) {
+            createFlashEnvironment(input: $input) {
+                id
+                name
+            }
+        }
+        """
+
+        variables = {"input": input_data}
+
+        log.debug(
+            f"Creating flash environment with GraphQL: {input_data.get('name', 'unnamed')}"
+        )
+
+        result = await self._execute_graphql(mutation, variables)
+
+        return result
+
     async def close(self):
         """Close the HTTP session."""
         if self.session and not self.session.closed:
diff --git a/src/tetra_rp/core/resources/__init__.py b/src/tetra_rp/core/resources/__init__.py
index 10aa5bc..f99153e 100644
--- a/src/tetra_rp/core/resources/__init__.py
+++ b/src/tetra_rp/core/resources/__init__.py
@@ -13,6 +13,7 @@
 from .serverless_cpu import CpuServerlessEndpoint
 from .template import PodTemplate
 from .network_volume import NetworkVolume, DataCenter
+from .project import FlashProject
 
 
 __all__ = [
@@ -34,4 +35,5 @@
     "ServerlessEndpoint",
     "PodTemplate",
     "NetworkVolume",
+    "FlashProject",
 ]
diff --git a/src/tetra_rp/core/resources/project.py b/src/tetra_rp/core/resources/project.py
new file mode 100644
index 0000000..aacb323
--- /dev/null
+++ b/src/tetra_rp/core/resources/project.py
@@ -0,0 +1,213 @@
+import asyncio
+import pathlib
+from typing import Callable, Dict, TYPE_CHECKING
+
+import requests
+
+from ..api.runpod import RunpodGraphQLClient
+from ..resources.resource_manager import ResourceManager
+
+if TYPE_CHECKING:
+    from . import ServerlessResource
+
+# (connect, read) timeouts in seconds for artifact transfers, so a stalled
+# connection cannot block the caller forever.
+_HTTP_TIMEOUT = (10, 300)
+# requests yields one byte per chunk by default, which makes large downloads
+# needlessly slow; stream in 1 MiB chunks instead.
+_DOWNLOAD_CHUNK_SIZE = 1024 * 1024
+
+
+class FlashProject:
+    """
+    A named flash project and the serverless resources registered on it.
+
+    Constructing an instance looks the project up by name through the Runpod
+    GraphQL API and creates it when it is not found, so ``__init__`` performs
+    network I/O.
+    """
+
+    def __init__(self, name: str):
+        self.name: str = name
+        self.id: str = ""
+        self.resources: Dict[str, "ServerlessResource"] = {}
+        # Blocks until the project id is known. asyncio.run() cannot be used
+        # while another event loop is running in this thread.
+        asyncio.run(self._get_or_create_self())
+
+    def remote(self, *args, **kwargs):
+        """
+        Record the resource config on this project and delegate to
+        ``tetra_rp.client.remote`` with the same arguments.
+
+        The config is the ``resource_config`` keyword or, when that is absent,
+        the first positional argument. It is stored in ``self.resources``
+        under its ``resource_id`` if it has one.
+        """
+        from tetra_rp.client import remote as remote_decorator
+
+        candidate = kwargs.get("resource_config")
+        if candidate is None and args:
+            candidate = args[0]
+        if hasattr(candidate, "resource_id"):
+            self.resources[candidate.resource_id] = candidate
+
+        return remote_decorator(*args, **kwargs)
+
+    async def _get_or_create_self(self):
+        """
+        Resolve ``self.id``, creating the project when it does not exist.
+
+        Returns:
+            The GraphQL result of the lookup, or of the create mutation when
+            the project had to be created.
+        """
+        async with RunpodGraphQLClient() as client:
+            try:
+                result = await client.get_flash_project_by_name(self.name)
+            except Exception as exc:
+                # Only a "project not found" error means the project must be
+                # created; anything else is re-raised.
+                if "project not found" not in str(exc).lower():
+                    raise
+                project = None
+            else:
+                # Treat a null ``flashProjectByName`` as "not found" too, as
+                # ``_get_id_by_name`` does.
+                project = result.get("flashProjectByName")
+
+            if not project:
+                result = await client.create_flash_project({"name": self.name})
+                project = result["createFlashProject"]
+
+            self.id = project["id"]
+            return result
+
+    async def _get_id_by_name(self):
+        """
+        Look up this project's id by name.
+
+        Raises:
+            ValueError: If the lookup returns no project.
+        """
+        async with RunpodGraphQLClient() as client:
+            result = await client.get_flash_project_by_name(self.name)
+            if not result.get("flashProjectByName"):
+                raise ValueError("flash project not found", self.name)
+            return result["flashProjectByName"]["id"]
+
+    async def create_environment(self, environment_name: str):
+        """Create an environment in this project and return its ``id``/``name``."""
+        async with RunpodGraphQLClient() as client:
+            result = await client.create_flash_environment(
+                {"flashProjectId": self.id, "name": environment_name}
+            )
+            return result["createFlashEnvironment"]
+
+    @staticmethod
+    async def list():
+        """Return the flash projects from ``list_flash_projects``."""
+        async with RunpodGraphQLClient() as client:
+            return await client.list_flash_projects()
+
+    async def _get_tarball_upload_url(self):
+        """Request an artifact upload URL for this project."""
+        async with RunpodGraphQLClient() as client:
+            return await client.prepare_artifact_upload({"projectId": self.id})
+
+    async def _get_active_artifact(self, environment_id: str):
+        """
+        Return the active artifact of an environment.
+
+        Raises:
+            ValueError: If the environment has no active artifact.
+        """
+        async with RunpodGraphQLClient() as client:
+            result = await client.get_flash_artifact_url(environment_id)
+            if not result["flashEnvironment"].get("activeArtifact"):
+                raise ValueError(
+                    "No active artifact for environment id found", environment_id
+                )
+            return result["flashEnvironment"]["activeArtifact"]
+
+    async def deploy_build_to_environment(self, environment_id: str, build_id: str):
+        """Deploy build ``build_id`` to environment ``environment_id``."""
+        async with RunpodGraphQLClient() as client:
+            result = await client.deploy_build_to_environment(
+                {"environmentId": environment_id, "flashBuildId": build_id}
+            )
+            return result
+
+    async def download_tarball(self, environment_id: str, dest_file: str):
+        """
+        Download the environment's active artifact to ``dest_file``.
+
+        Raises:
+            ValueError: If the environment has no active artifact.
+            requests.HTTPError: If the download request fails.
+        """
+        artifact = await self._get_active_artifact(environment_id)
+        # requests is blocking, so run the transfer in a worker thread to
+        # keep the event loop responsive.
+        await asyncio.to_thread(
+            self._download_to_file, artifact["downloadUrl"], dest_file
+        )
+
+    @staticmethod
+    def _download_to_file(url: str, dest_file: str) -> None:
+        """Stream ``url`` into ``dest_file`` (blocking)."""
+        with requests.get(url, stream=True, timeout=_HTTP_TIMEOUT) as resp:
+            # Check the status before opening the destination so a failed
+            # request does not leave an empty or truncated file behind.
+            resp.raise_for_status()
+            with open(dest_file, "wb") as stream:
+                for chunk in resp.iter_content(chunk_size=_DOWNLOAD_CHUNK_SIZE):
+                    if chunk:
+                        stream.write(chunk)
+
+    async def _finalize_tarball_upload(self, object_key: str):
+        """Finalize the upload stored at ``object_key`` for this project."""
+        async with RunpodGraphQLClient() as client:
+            result = await client.finalize_artifact_upload(
+                {"projectId": self.id, "objectKey": object_key}
+            )
+            return result["finalizeFlashArtifactUpload"]
+
+    async def upload_tarball(self, tar_path: str):
+        """
+        Upload the tarball at ``tar_path`` as a new artifact of this project.
+
+        Returns:
+            The ``finalizeFlashArtifactUpload`` payload for the upload.
+
+        Raises:
+            requests.HTTPError: If the upload request fails.
+        """
+        prepared = await self._get_tarball_upload_url()
+        upload = prepared["prepareFlashArtifactUpload"]
+
+        # requests is blocking, so upload from a worker thread.
+        await asyncio.to_thread(
+            self._upload_file, upload["uploadUrl"], pathlib.Path(tar_path)
+        )
+        return await self._finalize_tarball_upload(upload["objectKey"])
+
+    @staticmethod
+    def _upload_file(url: str, path: pathlib.Path) -> None:
+        """PUT the file at ``path`` to ``url`` as a tar archive (blocking)."""
+        headers = {"Content-Type": "application/x-tar"}
+        with path.open("rb") as fh:
+            resp = requests.put(url, data=fh, headers=headers, timeout=_HTTP_TIMEOUT)
+        resp.raise_for_status()
+
+    async def _deploy_in_environment(self, environment: str):
+        """
+        Entrypoint for cpu sls endpoint to execute provisioning its registered resources.
+        Goes through all registered resources and gets or deploys them
+        Should update app env state as Ready at the end
+        TODO(jhcipar) should add flash env into resource identifiers
+        """
+        resource_manager = ResourceManager()
+        # ``environment`` is not used yet; see the TODO above.
+        for resource in self.resources.values():
+            await resource_manager.get_or_deploy_resource(resource)