diff --git a/optimum/habana/accelerate/accelerator.py b/optimum/habana/accelerate/accelerator.py index 67aa9b8984..e5c0ea9ea9 100644 --- a/optimum/habana/accelerate/accelerator.py +++ b/optimum/habana/accelerate/accelerator.py @@ -568,7 +568,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e self._models[-1] = model # torch.compile should be called last and only if the model isn't already compiled. if self.state.dynamo_plugin.backend != GaudiDynamoBackend.NO and not is_compiled_module(model): - model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs()) + if self.dynamic is not None: + model = torch.compile(model, dynamic=self.dynamic, **self.state.dynamo_plugin.to_kwargs()) + else: + model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs()) return model def _prepare_deepspeed(self, *args): diff --git a/tests/test_examples.py b/tests/test_examples.py index c5668e5b7c..78aabcefda 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -319,6 +319,7 @@ def __new__( fsdp=False, torch_compile=False, fp8=False, + compile_dynamic: Optional[bool] = None, ): distribution = "single_card" if multi_card: @@ -355,6 +356,7 @@ def _create_test( fsdp: bool = False, torch_compile: bool = False, fp8: bool = False, + compile_dynamic: Optional[bool] = None, ) -> Callable[[], None]: """ Create a test function that runs an example for a specific (model_name, gaudi_config_name) pair. @@ -479,6 +481,8 @@ def test(self): ): extra_command_line_arguments.append("--torch_compile_backend hpu_backend") extra_command_line_arguments.append("--torch_compile") + if compile_dynamic is not None: + extra_command_line_arguments.append(f"--compile_dynamic {compile_dynamic}") if "--use_hpu_graphs_for_inference" in extra_command_line_arguments: extra_command_line_arguments.remove("--use_hpu_graphs_for_inference") env_variables["PT_HPU_LAZY_MODE"] = "0" @@ -774,6 +778,16 @@ class MultiCardSummarizationExampleTester( TASK_NAME = "cnn_dailymail" +class MultiCardDynamicCompileSummarizationExampleTester( + ExampleTesterBase, + metaclass=ExampleTestMeta, + example_name="run_summarization", + multi_card=True, + compile_dynamic=True, +): + TASK_NAME = "cnn_dailymail" + + class DeepspeedSummarizationExampleTester( ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_summarization", deepspeed=True ):