diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py index fd98f776eb..f5f944a5a5 100644 --- a/optimum/habana/accelerate/accelerator.py +++ b/optimum/habana/accelerate/accelerator.py @@ -21,6 +21,7 @@ from dataclasses import make_dataclass from types import MethodType +import accelerate.utils.other import torch from accelerate import Accelerator from accelerate.accelerator import _split_batches @@ -66,7 +67,6 @@ logger = get_logger(__name__) -# TODO: Compare to fullgraph=False in torch.compile def compile_regions(model, compile_kwargs): if isinstance(model, torch.nn.ModuleList): for name, module in model.named_children(): @@ -705,3 +705,10 @@ def prepare_data_loader( ) self._dataloaders.append(prepared_data_loader) return prepared_data_loader + + +def patch_has_compiled_regions(*args, **kwargs): + return False + + +accelerate.utils.other.has_compiled_regions = patch_has_compiled_regions