| 
10 | 10 | 
 
  | 
11 | 11 | from runpod import create_pod, get_pod  | 
12 | 12 | from runpod.cli.utils.ssh_cmd import SSHConnection  | 
 | 13 | +from runpod import error as rp_error  | 
13 | 14 | from .helpers import get_project_pod  | 
14 | 15 | from ...utils.rp_sync import sync_directory  | 
15 | 16 | 
 
  | 
@@ -56,7 +57,7 @@ def create_new_project(project_name, runpod_volume_id, python_version,  | 
56 | 57 |             'uuid': str(uuid.uuid4())[:8],  # Short UUID  | 
57 | 58 |             'name': project_name,  | 
58 | 59 |             'base_image': 'runpod/base:0.0.0',  | 
59 |  | -            'gpu': 'NVIDIA RTX A4500',  | 
 | 60 | +            'gpu_types': ['NVIDIA RTX A4500'],  | 
60 | 61 |             'gpu_count': 1,  | 
61 | 62 |             'storage_id': runpod_volume_id,  | 
62 | 63 |             'volume_mount_path': '/runpod-volume',  | 
@@ -103,24 +104,42 @@ def launch_project():  | 
103 | 104 |         raise ValueError('Project pod already launched. Run "runpod project start" to start.')  | 
104 | 105 | 
 
  | 
105 | 106 |     print("Launching pod on RunPod...")  | 
106 |  | - | 
107 | 107 |     environment_variables = {"RUNPOD_PROJECT_ID": config["PROJECT"]["UUID"]}  | 
108 | 108 |     for variable in config['project']['env_vars']:  | 
109 | 109 |         environment_variables[variable] = config['project']['env_vars'][variable]  | 
110 |  | - | 
111 |  | -    new_pod = create_pod(  | 
112 |  | -        f'{config["PROJECT"]["Name"]}-dev ({config["PROJECT"]["UUID"]})',  | 
113 |  | -        config['PROJECT']['BaseImage'],  | 
114 |  | -        config['PROJECT']['GPU'],  | 
115 |  | -        gpu_count=int(config['PROJECT']['GPUCount']),  | 
116 |  | -        support_public_ip=True,  | 
117 |  | -        ports=f'{config["PROJECT"]["Ports"]}',  | 
118 |  | -        network_volume_id=f'{config["PROJECT"]["StorageID"]}',  | 
119 |  | -        volume_mount_path=f'{config["PROJECT"]["VolumeMountPath"]}',  | 
120 |  | -        container_disk_in_gb=int(config["PROJECT"]["ContainerDiskSizeGB"]),  | 
121 |  | -        env=environment_variables  | 
122 |  | -    )  | 
123 |  | - | 
 | 110 | +      | 
 | 111 | + | 
 | 112 | +    #supply as toml list of gpu types  | 
 | 113 | +    selected_gpu_types = config['PROJECT'].get('GPU_TYPES',[])  | 
 | 114 | +    #supply as comma-separated list of gpu types (deprecated)  | 
 | 115 | +    selected_gpu_types.extend(list(map(lambda s: s.strip(),config['PROJECT']['GPU'].split(','))) if 'GPU' in config['PROJECT'] else [])    | 
 | 116 | +    new_pod = None  | 
 | 117 | +    successful_gpu_type = None  | 
 | 118 | +    for gpu_type in selected_gpu_types:  | 
 | 119 | +        print(f"Trying to get a pod with {gpu_type}...")  | 
 | 120 | +        try:  | 
 | 121 | +            new_pod = create_pod(  | 
 | 122 | +                f'{config["PROJECT"]["Name"]}-dev ({config["PROJECT"]["UUID"]})',  | 
 | 123 | +                config['PROJECT']['BaseImage'],  | 
 | 124 | +                gpu_type,  | 
 | 125 | +                gpu_count=int(config['PROJECT']['GPUCount']),  | 
 | 126 | +                support_public_ip=True,  | 
 | 127 | +                ports=f'{config["PROJECT"]["Ports"]}',  | 
 | 128 | +                network_volume_id=f'{config["PROJECT"]["StorageID"]}',  | 
 | 129 | +                volume_mount_path=f'{config["PROJECT"]["VolumeMountPath"]}',  | 
 | 130 | +                container_disk_in_gb=int(config["PROJECT"]["ContainerDiskSizeGB"]),  | 
 | 131 | +                env={"RUNPOD_PROJECT_ID": config["PROJECT"]["UUID"]}  | 
 | 132 | +            )  | 
 | 133 | +            successful_gpu_type = gpu_type  | 
 | 134 | +            break  | 
 | 135 | +        except rp_error.QueryError:  | 
 | 136 | +            print(f"Couldn't obtain a {gpu_type}")  | 
 | 137 | +    if new_pod is None:  | 
 | 138 | +        print("Couldn't obtain any of the selected gpu types, try again later or use a different type")  | 
 | 139 | +        return  | 
 | 140 | +    print(f"Got a pod with {successful_gpu_type} ({new_pod['id']})")  | 
 | 141 | + | 
 | 142 | +    print("Waiting for pod to come online...")  | 
124 | 143 |     while new_pod.get('desiredStatus', None) != 'RUNNING' or new_pod.get('runtime', None) is None:  | 
125 | 144 |         new_pod = get_pod(new_pod['id'])  | 
126 | 145 | 
 
  | 
 | 
0 commit comments