diff --git a/runpod/api/ctl_commands.py b/runpod/api/ctl_commands.py index 1c43125c..d9a70161 100644 --- a/runpod/api/ctl_commands.py +++ b/runpod/api/ctl_commands.py @@ -61,7 +61,7 @@ def create_pod( gpu_count:int=1, volume_in_gb:int=0, container_disk_in_gb:int=5, min_vcpu_count:int=1, min_memory_in_gb:int=1, docker_args:str="", ports:Optional[str]=None, volume_mount_path:str="/workspace", - env:Optional[dict]=None + env:Optional[dict]=None, template_id:Optional[str]=None ) -> dict: ''' Create a pod @@ -79,6 +79,7 @@ def create_pod( :param env: the environment variables to inject into the pod, for example {EXAMPLE_VAR:"example_value", EXAMPLE_VAR2:"example_value 2"}, will inject EXAMPLE_VAR and EXAMPLE_VAR2 into the pod with the mentioned values + :param template_id: the id of the template to use for the pod :example: @@ -95,7 +96,7 @@ def create_pod( cloud_type, support_public_ip, data_center_id, country_code, gpu_count, volume_in_gb, container_disk_in_gb, min_vcpu_count, min_memory_in_gb, docker_args, - ports, volume_mount_path, env) + ports, volume_mount_path, env, template_id) ) cleaned_response = raw_response["data"]["podFindAndDeployOnDemand"] diff --git a/runpod/api/mutations/pods.py b/runpod/api/mutations/pods.py index a047f14a..eb8818ca 100644 --- a/runpod/api/mutations/pods.py +++ b/runpod/api/mutations/pods.py @@ -10,7 +10,7 @@ def generate_pod_deployment_mutation( data_center_id=None, country_code=None, gpu_count=None, volume_in_gb=None, container_disk_in_gb=None, min_vcpu_count=None, min_memory_in_gb=None, docker_args=None, ports=None, volume_mount_path=None, - env=None): + env=None, template_id=None): ''' Generates a mutation to deploy a pod on demand. ''' @@ -55,6 +55,8 @@ def generate_pod_deployment_mutation( env_string = ", ".join( [f'{{ key: "{key}", value: "{value}" }}' for key, value in env.items()]) input_fields.append(f"env: [{env_string}]") + if template_id is not None: + input_fields.append(f'templateId: "{template_id}"') # Format input fields diff --git a/tests/test_api/test_mutations_pods.py b/tests/test_api/test_mutations_pods.py index f4e1083e..a587872c 100644 --- a/tests/test_api/test_mutations_pods.py +++ b/tests/test_api/test_mutations_pods.py @@ -27,7 +27,8 @@ def test_generate_pod_deployment_mutation(self): ports="8080", volume_mount_path="/path", env={"ENV": "test"}, - support_public_ip=True) + support_public_ip=True, + template_id="abcde") # Here you should check the correct structure of the result self.assertIn("mutation", result)