diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py
index df37b10a30..78cc79a238 100644
--- a/examples/text-generation/utils.py
+++ b/examples/text-generation/utils.py
@@ -245,8 +245,12 @@ def setup_model(args, model_dtype, model_kwargs, logger):
             args.model_name_or_path, torch_dtype=model_dtype, quantization_config=quantization_config, **model_kwargs
         )
     elif args.load_quantized_model_with_inc:
-        from neural_compressor.torch.quantization import load
+        # TODO: This will be removed in the Synapse v1.19 release.
+        # Override neural_compressor's _load_remaining_pretrained_weight for the Transformers 4.45 release.
+        import neural_compressor.torch.algorithms.weight_only.save_load as nc_sl
+        nc_sl.WOQModelLoader._load_remaining_pretrained_weight = local_load_remaining_pretrained_weight
+        from neural_compressor.torch.quantization import load
 
         model = load(model_name_or_path=args.model_name_or_path, format="huggingface", device="hpu", **model_kwargs)
     elif args.local_quantized_inc_model_path:
         org_model = AutoModelForCausalLM.from_pretrained(
@@ -662,3 +666,46 @@ def initialize_model(args, logger):
     logger.info(f"device: {args.device}, n_hpu: {args.world_size}, bf16: {model_dtype == torch.bfloat16}")
     logger.info(f"Model initialization took {(init_end - init_start):.3f}s")
     return model, assistant_model, tokenizer, generation_config
+
+# TODO: This will be removed in the Synapse v1.19 release.
+# This overrides _load_remaining_pretrained_weight for the Transformers 4.45 release.
+def local_load_remaining_pretrained_weight(self, model):
+    from transformers.modeling_utils import _load_state_dict_into_meta_model, load_state_dict
+
+    resolved_archive_file = self.kwargs.pop("resolved_archive_file", None)
+    torch_dtype = self.kwargs.pop("torch_dtype", torch.float32)
+    dtype_orig = self.kwargs.pop("dtype_orig", None)
+    offload_folder = self.kwargs.pop("offload_folder", None)
+    offload_state_dict = self.kwargs.pop("offload_state_dict", False)
+
+    # restore default dtype
+    if dtype_orig is not None:
+        torch.set_default_dtype(dtype_orig)
+
+    if not isinstance(resolved_archive_file, list):
+        resolved_archive_file = [resolved_archive_file]
+    for shard_file in resolved_archive_file:
+        state_dict = load_state_dict(shard_file)
+
+        params_dict = {
+            "model": model,
+            "state_dict": state_dict,
+            "start_prefix": "",
+            "expected_keys": list(state_dict.keys()),
+            "device_map": {"": self.device},
+            "offload_folder": offload_folder,
+            "state_dict_folder": tempfile.mkdtemp() if offload_state_dict else None,
+            "state_dict_index": {} if offload_state_dict else None,
+            "dtype": torch_dtype,
+            "keep_in_fp32_modules": [],
+        }
+
+        _load_state_dict_into_meta_model(**params_dict)
+
+    # make sure token embedding weights are still tied if needed
+    model.tie_weights()
+
+    # Set model in evaluation mode to deactivate DropOut modules by default
+    model.eval()
+
+    return model
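
Reviewer note: the helper above assumes `tempfile` and `torch` are already imported at module level in utils.py. Below is a minimal sketch of how the override in this patch takes effect at load time; the module path, class name, and the `load()` call mirror the diff, while the import of the helper from utils and the placeholder checkpoint path are assumptions for illustration only.

# Sketch only: how the monkey-patch in this diff is exercised.
import neural_compressor.torch.algorithms.weight_only.save_load as nc_sl

from utils import local_load_remaining_pretrained_weight  # helper added by this patch

# Replace INC's weight-loading step with the Transformers-4.45-compatible local version.
nc_sl.WOQModelLoader._load_remaining_pretrained_weight = local_load_remaining_pretrained_weight

from neural_compressor.torch.quantization import load

# "path/to/inc-quantized-model" is a placeholder for an INC-quantized checkpoint.
model = load(model_name_or_path="path/to/inc-quantized-model", format="huggingface", device="hpu")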