diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 66690c9b05..de8f27762c 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -177,7 +177,7 @@ def patch_scoped_linear_all_reduce(model): def get_torch_compiled_model(model): - if model.config.model_type in ["gpt_bigcode", "mpt", "bloom"]: + if model.config.model_type in ["gpt_bigcode", "mpt", "bloom", "gpt2"]: model.transformer = torch.compile( model.transformer, backend="hpu_backend", options={"keep_input_mutations": True} )