diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 96253f7726..f8dc6cdd32 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -100,7 +100,9 @@ def setup_inference(args, model): import habana_frameworks.torch.core as htcore print("Initializing inference mode") - htcore.hpu_initialize(model) + const_marking = os.getenv("ENABLE_CONST_MARKING", "True") + if const_marking == "True": + htcore.hpu_initialize(model) return model def setup_const_serialization(const_serialization_path):