diff --git a/build/utils.py b/build/utils.py
index 52b4efee00..975b3c3803 100644
--- a/build/utils.py
+++ b/build/utils.py
@@ -5,7 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 from __future__ import annotations
-
+from typing import List
+from pathlib import Path
+import os
 import logging
 
 import torch
@@ -33,6 +35,8 @@ def name_to_dtype(name):
     else:
         raise RuntimeError(f"unsupported dtype name {name} specified")
 
+def allowable_dtype_names() -> List[str]:
+    return list(name_to_dtype_dict.keys())
 
 name_to_dtype_dict = {
     "fp32": torch.float,
@@ -45,6 +49,18 @@ def name_to_dtype(name):
     "bfloat16": torch.bfloat16,
 }
 
+
+#########################################################################
+### general model build utility functions for CLI ###
+
+def allowable_params_table() -> List[str]:
+    config_path = Path(f"{str(Path(__file__).parent)}/known_model_params")
+    known_model_params = [
+        config.replace(".json", "") for config in os.listdir(config_path)
+    ]
+    return known_model_params
+
+
 
 #########################################################################
 ### general model build utility functions ###
diff --git a/cli.py b/cli.py
index beac84c80f..a1656e928d 100644
--- a/cli.py
+++ b/cli.py
@@ -7,6 +7,8 @@
 import json
 from pathlib import Path
 
+from build.utils import allowable_dtype_names, allowable_params_table
+
 import torch
 
 # CPU is always available and also exportable to ExecuTorch
@@ -208,6 +210,7 @@ def add_arguments(parser):
         "-d",
         "--dtype",
         default="float32",
+        choices=allowable_dtype_names(),
         help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32",
     )
     parser.add_argument(
@@ -239,12 +242,14 @@ def add_arguments(parser):
         "--params-table",
         type=str,
         default=None,
+        choices=allowable_params_table(),
         help="Parameter table to use",
     )
     parser.add_argument(
         "--device",
         type=str,
         default=default_device,
+        choices=["cpu", "cuda", "mps"],
         help="Hardware device to use. Options: cpu, cuda, mps",
     )
     parser.add_argument(
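Illustrative sketch (not part of the patch): how the new helpers are intended to drive argparse validation. Only allowable_dtype_names() and allowable_params_table() come from build/utils.py; the standalone parser below is a minimal stand-in for the real add_arguments() in cli.py.

    # Sketch: with `choices` wired to the new helpers, unsupported values are
    # rejected at argument-parsing time instead of failing later in name_to_dtype().
    import argparse

    from build.utils import allowable_dtype_names, allowable_params_table

    parser = argparse.ArgumentParser()
    parser.add_argument("--dtype", default="float32", choices=allowable_dtype_names())
    parser.add_argument("--params-table", default=None, choices=allowable_params_table())

    args = parser.parse_args(["--dtype", "bfloat16"])  # accepted: a key of name_to_dtype_dict
    # parser.parse_args(["--dtype", "int4"])           # exits with "invalid choice: 'int4'"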