Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion optimum/habana/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,10 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
self._models[-1] = model
# torch.compile should be called last and only if the model isn't already compiled.
if self.state.dynamo_plugin.backend != GaudiDynamoBackend.NO and not is_compiled_module(model):
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
if self.dynamic is not None:
model = torch.compile(model, dynamic=self.dynamic, **self.state.dynamo_plugin.to_kwargs())
else:
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
return model

def _prepare_deepspeed(self, *args):
Expand Down
14 changes: 14 additions & 0 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ def __new__(
fsdp=False,
torch_compile=False,
fp8=False,
compile_dynamic: Optional[bool] = None,
):
distribution = "single_card"
if multi_card:
Expand Down Expand Up @@ -355,6 +356,7 @@ def _create_test(
fsdp: bool = False,
torch_compile: bool = False,
fp8: bool = False,
compile_dynamic: Optional[bool] = None,
) -> Callable[[], None]:
"""
Create a test function that runs an example for a specific (model_name, gaudi_config_name) pair.
Expand Down Expand Up @@ -479,6 +481,8 @@ def test(self):
):
extra_command_line_arguments.append("--torch_compile_backend hpu_backend")
extra_command_line_arguments.append("--torch_compile")
if compile_dynamic is not None:
extra_command_line_arguments.append(f"--compile_dynamic {compile_dynamic}")
if "--use_hpu_graphs_for_inference" in extra_command_line_arguments:
extra_command_line_arguments.remove("--use_hpu_graphs_for_inference")
env_variables["PT_HPU_LAZY_MODE"] = "0"
Expand Down Expand Up @@ -774,6 +778,16 @@ class MultiCardSummarizationExampleTester(
TASK_NAME = "cnn_dailymail"


class MultiCardDynamicCompileSummarizationExampleTester(
ExampleTesterBase,
metaclass=ExampleTestMeta,
example_name="run_summarization",
multi_card=True,
compile_dynamic=True,
):
TASK_NAME = "cnn_dailymail"


class DeepspeedSummarizationExampleTester(
ExampleTesterBase, metaclass=ExampleTestMeta, example_name="run_summarization", deepspeed=True
):
Expand Down