diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index 821f790604..53e4c3bab6 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -174,9 +174,8 @@ def setup_model(args, model_dtype, model_kwargs, logger): if args.use_hpu_graphs: from habana_frameworks.torch.hpu import wrap_in_hpu_graph - if check_habana_frameworks_version("1.13.0"): - if model.config.model_type == "falcon": - model = wrap_in_hpu_graph(model, hash_with_views=False) + if check_habana_frameworks_version("1.13.0") and model.config.model_type == "falcon": + model = wrap_in_hpu_graph(model, hash_with_views=False) else: model = wrap_in_hpu_graph(model)