|
18 | 18 | SPOT_AVAILABILITY_ENDPOINT = f'{BASE_URL}/v2/spot/availability' |
19 | 19 | DEFAULT_CREDENTIALS_PATH = os.path.expanduser('~/.flow/config.yaml') |
20 | 20 |
|
| 21 | + |
| 22 | +def get_output_path() -> str: |
| 23 | + """Get output path for catalog file.""" |
| 24 | + current_dir = os.getcwd() |
| 25 | + if os.path.basename(current_dir) == 'mithril': |
| 26 | + return 'vms.csv' |
| 27 | + mithril_dir = os.path.join(current_dir, 'mithril') |
| 28 | + os.makedirs(mithril_dir, exist_ok=True) |
| 29 | + return os.path.join(mithril_dir, 'vms.csv') |
| 30 | + |
| 31 | + |
21 | 32 | # GPU memory mapping (in MiB) |
22 | 33 | GPU_MEMORY_MAP = { |
23 | 34 | 'A100': 40960, # 40 GB |
@@ -173,17 +184,17 @@ def fetch_spot_availability(api_key: str) -> Dict[str, List[Dict[str, Any]]]: |
173 | 184 | return availability |
174 | 185 |
|
175 | 186 |
|
176 | | -def create_catalog(output_path: str, api_key: Optional[str] = None) -> None: |
| 187 | +def create_catalog(api_key: Optional[str] = None) -> None: |
177 | 188 | """Create Mithril catalog CSV file.""" |
178 | 189 | print('Fetching Mithril instance types...') |
179 | 190 | api_key = get_api_key(api_key) |
180 | 191 | instance_types = fetch_instance_types(api_key) |
181 | 192 | availability = fetch_spot_availability(api_key) |
182 | 193 |
|
183 | 194 | print(f'Found {len(instance_types)} instance types') |
184 | | - print(f'Writing catalog to {output_path}') |
185 | 195 |
|
186 | | - os.makedirs(os.path.dirname(output_path), exist_ok=True) |
| 196 | + output_path = get_output_path() |
| 197 | + print(f'Writing catalog to {output_path}') |
187 | 198 |
|
188 | 199 | # Track unique instance type names to handle duplicates |
189 | 200 | # Mithril API can return multiple configs with the same name |
@@ -259,17 +270,13 @@ def main(): |
259 | 270 | """Main entry point.""" |
260 | 271 | parser = argparse.ArgumentParser( |
261 | 272 | description='Fetch Mithril Cloud catalog data') |
262 | | - parser.add_argument('--output', |
263 | | - type=str, |
264 | | - default='~/.sky/catalogs/v8/mithril/vms.csv', |
265 | | - help='Output path for the catalog CSV file') |
266 | 273 | parser.add_argument('--api-key', type=str, help='Mithril API key') |
267 | 274 |
|
268 | 275 | args = parser.parse_args() |
269 | | - output_path = os.path.expanduser(args.output) |
270 | 276 |
|
271 | 277 | try: |
272 | | - create_catalog(output_path, args.api_key) |
| 278 | + create_catalog(args.api_key) |
| 279 | + print(f'Mithril Service Catalog saved to {get_output_path()}') |
273 | 280 | return 0 |
274 | 281 | except Exception as e: # pylint: disable=broad-except |
275 | 282 | print(f'Error: {e}') |
|
0 commit comments