diff --git a/examples/text-generation/utils.py b/examples/text-generation/utils.py index d6152c90fa..211e58861c 100644 --- a/examples/text-generation/utils.py +++ b/examples/text-generation/utils.py @@ -36,7 +36,7 @@ model_on_meta, write_checkpoints_json, ) -from optimum.habana.utils import check_habana_frameworks_min_version, check_optimum_habana_min_version, set_seed +from optimum.habana.utils import check_habana_frameworks_version, check_optimum_habana_min_version, set_seed def override_print(enable): @@ -132,7 +132,8 @@ def setup_model(args, model_dtype, model_kwargs, logger): if args.use_hpu_graphs: from habana_frameworks.torch.hpu import wrap_in_hpu_graph - if check_habana_frameworks_min_version("1.13.0"): + # TODO: remove the following check from SynapseAI v1.15 + if check_habana_frameworks_version("1.13.0"): if model.config.model_type == "falcon": args.skip_hash_with_views = True model = wrap_in_hpu_graph(model, hash_with_views=not args.skip_hash_with_views) diff --git a/optimum/habana/utils.py b/optimum/habana/utils.py index 20fe0735ac..702aea456b 100644 --- a/optimum/habana/utils.py +++ b/optimum/habana/utils.py @@ -336,6 +336,13 @@ def check_habana_frameworks_min_version(min_version): return True +def check_habana_frameworks_version(req_version): + """ + Checks if the installed version of `habana_frameworks` is equal to `req_version`. + """ + return get_habana_frameworks_version() == version.parse(req_version) + + def get_device_name(): """ Returns the name of the current device: Gaudi or Gaudi2.