diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 6b0b2e4695..2c89c63256 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -239,7 +239,17 @@ def setup_parser(parser): ) parser.add_argument("--temperature", default=1.0, type=float, help="Temperature value for text generation") parser.add_argument("--top_p", default=1.0, type=float, help="Top_p value for generating text via sampling") - + parser.add_argument( + "--const_serialization_path", + "--csp", + type=str, + help="Path to serialize const params. Const params will be held on disk memory instead of being allocated on host memory.", + ) + parser.add_argument( + "--disk_offload", + action="store_true", + help="Whether to enable device map auto. In case no space left on cpu, weights will be offloaded to disk.", + ) args = parser.parse_args() if args.torch_compile: @@ -561,6 +571,10 @@ def generate_dataset(batch): import habana_quantization_toolkit habana_quantization_toolkit.finish_measurements(model) + if args.const_serialization_path and os.path.isdir(args.const_serialization_path): + import shutil + + shutil.rmtree(args.const_serialization_path) if __name__ == "__main__": diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py index 4ae8dcb26c..2870e2cb0e 100644 --- a/examples/text-generation/run_lm_eval.py +++ b/examples/text-generation/run_lm_eval.py @@ -176,6 +176,10 @@ def main(): import habana_quantization_toolkit habana_quantization_toolkit.finish_measurements(model) + if args.const_serialization_path and os.path.isdir(args.const_serialization_path): + import shutil + + shutil.rmtree(args.const_serialization_path) if __name__ == "__main__": diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index e8c847c2f7..f395cf24cc 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -98,12 +98,9 @@ def setup_distributed(args): def setup_quantization(args, model): import habana_frameworks.torch.core as htcore - from habana_frameworks.torch.core.quantization import _check_params_as_const, _mark_params_as_const from habana_frameworks.torch.hpu import hpu print("Initializing inference with quantization") - _mark_params_as_const(model) - _check_params_as_const(model) if not args.quant_config: hpu.enable_quantization() htcore.hpu_initialize(model) @@ -373,6 +370,10 @@ def initialize_model(args, logger): "revision": args.model_revision, "token": args.token, } + if args.disk_offload: + model_kwargs["device_map"] = "auto" + model_kwargs["offload_folder"] = "/tmp/offload_folder/" + model = ( setup_model(args, model_dtype, model_kwargs, logger) if not use_deepspeed @@ -380,6 +381,16 @@ def initialize_model(args, logger): ) tokenizer, model = setup_tokenizer(args, model) generation_config = setup_generation_config(args, model, tokenizer) + + if args.const_serialization_path: + import uuid + + args.const_serialization_path = os.path.join(args.const_serialization_path + uuid.uuid4().hex) + os.makedirs(args.const_serialization_path) + from habana_frameworks.torch.hpu import enable_const_section_serialization + + print("Serializing const params to {}".format(args.const_serialization_path)) + enable_const_section_serialization(args.const_serialization_path, True) if args.fp8: model = setup_quantization(args, model) init_end = time.perf_counter()