Skip to content

Commit

Permalink
cli (pytorch#398)
Browse files Browse the repository at this point in the history
* cli

* typos
  • Loading branch information
mikekgfb authored and malfet committed Jul 17, 2024
1 parent 9c98d75 commit 7be5645
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
18 changes: 17 additions & 1 deletion build/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import List
from pathlib import Path
import os
import logging

import torch
Expand Down Expand Up @@ -33,6 +35,8 @@ def name_to_dtype(name):
else:
raise RuntimeError(f"unsupported dtype name {name} specified")

def allowable_dtype_names() -> List[str]:
return name_to_dtype_dict.keys()

name_to_dtype_dict = {
"fp32": torch.float,
Expand All @@ -45,6 +49,18 @@ def name_to_dtype(name):
"bfloat16": torch.bfloat16,
}


#########################################################################
### general model build utility functions for CLI ###

def allowable_params_table() -> List[dtr]:
config_path = Path(f"{str(Path(__file__).parent)}/known_model_params")
known_model_params = [
config.replace(".json", "") for config in os.listdir(config_path)
]
return known_model_params


#########################################################################
### general model build utility functions ###

Expand Down
5 changes: 5 additions & 0 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import json
from pathlib import Path

from build.utils import allowable_dtype_names, allowable_params_table

import torch

# CPU is always available and also exportable to ExecuTorch
Expand Down Expand Up @@ -208,6 +210,7 @@ def add_arguments(parser):
"-d",
"--dtype",
default="float32",
choices = allowable_dtype_names(),
help="Override the dtype of the model (default is the checkpoint dtype). Options: bf16, fp16, fp32",
)
parser.add_argument(
Expand Down Expand Up @@ -239,12 +242,14 @@ def add_arguments(parser):
"--params-table",
type=str,
default=None,
choices=allowable_params_table(),
help="Parameter table to use",
)
parser.add_argument(
"--device",
type=str,
default=default_device,
choices=["cpu", "cuda", "mps"],
help="Hardware device to use. Options: cpu, cuda, mps",
)
parser.add_argument(
Expand Down

0 comments on commit 7be5645

Please sign in to comment.