Skip to content

Commit 7dfb8c5

Browse files
Merge pull request #177 from DireLines/tryGpuTypesInOrder
feat: try multiple gpu types in order
2 parents a2e7091 + 9a541a4 commit 7dfb8c5

File tree

1 file changed

+35
-16
lines changed

1 file changed

+35
-16
lines changed

runpod/cli/groups/project/functions.py

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from runpod import create_pod, get_pod
1212
from runpod.cli.utils.ssh_cmd import SSHConnection
13+
from runpod import error as rp_error
1314
from .helpers import get_project_pod
1415
from ...utils.rp_sync import sync_directory
1516

@@ -56,7 +57,7 @@ def create_new_project(project_name, runpod_volume_id, python_version,
5657
'uuid': str(uuid.uuid4())[:8], # Short UUID
5758
'name': project_name,
5859
'base_image': 'runpod/base:0.0.0',
59-
'gpu': 'NVIDIA RTX A4500',
60+
'gpu_types': ['NVIDIA RTX A4500'],
6061
'gpu_count': 1,
6162
'storage_id': runpod_volume_id,
6263
'volume_mount_path': '/runpod-volume',
@@ -103,24 +104,42 @@ def launch_project():
103104
raise ValueError('Project pod already launched. Run "runpod project start" to start.')
104105

105106
print("Launching pod on RunPod...")
106-
107107
environment_variables = {"RUNPOD_PROJECT_ID": config["PROJECT"]["UUID"]}
108108
for variable in config['project']['env_vars']:
109109
environment_variables[variable] = config['project']['env_vars'][variable]
110-
111-
new_pod = create_pod(
112-
f'{config["PROJECT"]["Name"]}-dev ({config["PROJECT"]["UUID"]})',
113-
config['PROJECT']['BaseImage'],
114-
config['PROJECT']['GPU'],
115-
gpu_count=int(config['PROJECT']['GPUCount']),
116-
support_public_ip=True,
117-
ports=f'{config["PROJECT"]["Ports"]}',
118-
network_volume_id=f'{config["PROJECT"]["StorageID"]}',
119-
volume_mount_path=f'{config["PROJECT"]["VolumeMountPath"]}',
120-
container_disk_in_gb=int(config["PROJECT"]["ContainerDiskSizeGB"]),
121-
env=environment_variables
122-
)
123-
110+
111+
112+
#supply as toml list of gpu types
113+
selected_gpu_types = config['PROJECT'].get('GPU_TYPES',[])
114+
#supply as comma-separated list of gpu types (deprecated)
115+
selected_gpu_types.extend(list(map(lambda s: s.strip(),config['PROJECT']['GPU'].split(','))) if 'GPU' in config['PROJECT'] else [])
116+
new_pod = None
117+
successful_gpu_type = None
118+
for gpu_type in selected_gpu_types:
119+
print(f"Trying to get a pod with {gpu_type}...")
120+
try:
121+
new_pod = create_pod(
122+
f'{config["PROJECT"]["Name"]}-dev ({config["PROJECT"]["UUID"]})',
123+
config['PROJECT']['BaseImage'],
124+
gpu_type,
125+
gpu_count=int(config['PROJECT']['GPUCount']),
126+
support_public_ip=True,
127+
ports=f'{config["PROJECT"]["Ports"]}',
128+
network_volume_id=f'{config["PROJECT"]["StorageID"]}',
129+
volume_mount_path=f'{config["PROJECT"]["VolumeMountPath"]}',
130+
container_disk_in_gb=int(config["PROJECT"]["ContainerDiskSizeGB"]),
131+
env={"RUNPOD_PROJECT_ID": config["PROJECT"]["UUID"]}
132+
)
133+
successful_gpu_type = gpu_type
134+
break
135+
except rp_error.QueryError:
136+
print(f"Couldn't obtain a {gpu_type}")
137+
if new_pod is None:
138+
print("Couldn't obtain any of the selected gpu types, try again later or use a different type")
139+
return
140+
print(f"Got a pod with {successful_gpu_type} ({new_pod['id']})")
141+
142+
print("Waiting for pod to come online...")
124143
while new_pod.get('desiredStatus', None) != 'RUNNING' or new_pod.get('runtime', None) is None:
125144
new_pod = get_pod(new_pod['id'])
126145

0 commit comments

Comments
 (0)