@@ -325,9 +325,8 @@ def main(
325325 max_new_tokens : int = 100 ,
326326 top_k : int = 200 ,
327327 temperature : float = 0.8 ,
328- checkpoint_path : Path = Path (
329- "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"
330- ),
328+ checkpoint_path : Optional [Path ] = None ,
329+ tokenizer_path : Optional [Path ] = None ,
331330 compile : bool = True ,
332331 compile_prefill : bool = False ,
333332 profile : Optional [Path ] = None ,
@@ -339,14 +338,21 @@ def main(
339338 quantize = None ,
340339) -> None :
341340 """Generates text samples based on a pre-trained Transformer model and tokenizer."""
342- assert checkpoint_path .is_file (), checkpoint_path
343-
344- torch .manual_seed (1234 )
345-
346- tokenizer_path = checkpoint_path .parent / "tokenizer.model"
341+ assert (
342+ (checkpoint_path and checkpoint_path .is_file ()) or
343+ (dso_path and Path (dso_path ).is_file ()) or
344+ (pte_path and Path (pte_path ).is_file ())
345+ ), "need to specified a valid checkpoint path, DSO path, or PTE path"
346+ assert not (dso_path and pte_path ), "specify either DSO path or PTE path, but not both"
347+
348+ if (checkpoint_path and (dso_path or pte_path )):
349+ print ("Warning: checkpoint path ignored because an exported DSO or PTE path specified" )
350+
351+ if not tokenizer_path :
352+ tokenizer_path = checkpoint_path .parent / "tokenizer.model"
347353 assert tokenizer_path .is_file (), tokenizer_path
348354
349- global print
355+ # global print
350356 # from tp import maybe_init_dist
351357 # rank = maybe_init_dist()
352358 use_tp = False
@@ -540,10 +546,22 @@ def cli():
540546 parser .add_argument (
541547 "--temperature" , type = float , default = 0.8 , help = "Temperature for sampling."
542548 )
549+ parser .add_argument (
550+ "--seed" ,
551+ type = int ,
552+ default = 1234 , # set None for release
553+ help = "Initialize torch seed"
554+ )
543555 parser .add_argument (
544556 "--checkpoint-path" ,
545557 type = Path ,
546- default = Path ("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth" ),
558+ default = None ,
559+ help = "Model checkpoint path." ,
560+ )
561+ parser .add_argument (
562+ "--tokenizer-path" ,
563+ type = Path ,
564+ default = None ,
547565 help = "Model checkpoint path." ,
548566 )
549567 parser .add_argument (
@@ -590,6 +608,10 @@ def cli():
590608
591609
592610 args = parser .parse_args ()
611+
612+ if args .seed :
613+ torch .manual_seed (args .seed )
614+
593615 main (
594616 args .prompt ,
595617 args .interactive ,
@@ -598,6 +620,7 @@ def cli():
598620 args .top_k ,
599621 args .temperature ,
600622 args .checkpoint_path ,
623+ args .tokenizer_path ,
601624 args .compile ,
602625 args .compile_prefill ,
603626 args .profile ,
0 commit comments