diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 610ecf1a8b..e46096c293 100755 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -335,6 +335,11 @@ def __call__(self, parser, namespace, values, option_string=None): default="none", help="Run multi card with the specified distributed strategy. Choices are 'tp' for Tensor Parallel Strategy or 'none'.", ) + parser.add_argument( + "--load_cp", + action="store_true", + help="Whether to load model from hugging face checkpoint.", + ) args = parser.parse_args() diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 67afcc7015..0a860f9bc9 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -242,6 +242,14 @@ def setup_model(args, model_dtype, model_kwargs, logger): model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, torch_dtype=model_dtype, quantization_config=quantization_config, **model_kwargs ) + elif args.load_cp: + from neural_compressor.torch.quantization import load + model = load( + model_name_or_path=args.model_name_or_path, + format="huggingface", + device="hpu", + **model_kwargs + ) else: if args.assistant_model is not None: assistant_model = AutoModelForCausalLM.from_pretrained( @@ -638,6 +646,9 @@ def initialize_model(args, logger): "token": args.token, "trust_remote_code": args.trust_remote_code, } + if args.load_cp: + model_kwargs["torch_dtype"] = torch.bfloat16 + if args.trust_remote_code: logger.warning("`trust_remote_code` is set, there is no guarantee this model works properly and it may fail")