Skip to content

Commit 73abaa8

Browse files
mikekgfbmalfet
authored andcommitted
Strict commandline (pytorch#157)
* add strict commandline enforcement * strict commandline enforcement capability * commandline enforcement vs warning * fix typo * typo
1 parent 33928a8 commit 73abaa8

File tree

3 files changed

+30
-1
lines changed

3 files changed

+30
-1
lines changed

cli.py

+27
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,33 @@
1313

1414
default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu'
1515

16+
strict = False
17+
18+
def check_args(args, command_name: str):
19+
global strict
20+
21+
# chat and generate support the same options
22+
if command_name in ["generate", "chat", "gui"]:
23+
# examples, can add more. Note that attributes convert dash to _
24+
disallowed_args = ['output_pte_path', 'output_dso_path' ]
25+
elif command_name == "export":
26+
# examples, can add more. Note that attributes convert dash to _
27+
disallowed_args = ['pte_path', 'dso_path' ]
28+
elif command_name == "eval":
29+
# TBD
30+
disallowed_args = []
31+
else:
32+
raise RuntimeError(f"{command_name} is not a valid command")
33+
34+
for disallowed in disallowed_args:
35+
if args.hasattr(disallow):
36+
text = f"command {command_name} does not support option {disallowed.replace('_', '-')}"
37+
if strict:
38+
raise RuntimeError(text)
39+
else:
40+
print(f"Warning: {text}")
41+
42+
1643
def cli_args():
1744
import argparse
1845

eval.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
except:
2929
lm_eval_available = False
3030

31-
from generate import _load_model, encode_tokens, model_forward
31+
from generate import _load_inference_model, encode_tokens, model_forward
3232

3333
if lm_eval_available:
3434
try: # lm_eval version 0.4

torchat.py

+2
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ def cli():
2323
args = cli_args()
2424

2525
if args.generate or args.chat:
26+
check_args(args, "generate")
2627
generate_main(args)
2728
elif args.eval:
2829
eval_main(args)
2930
elif args.export:
31+
check_args(args, "export")
3032
export_main(args)
3133
else:
3234
raise RuntimeError("must specify either --generate or --export")

0 commit comments

Comments
 (0)