Merged
src/transformers/utils/import_utils.py (4 changes: 3 additions & 1 deletion)

@@ -865,7 +865,7 @@ def is_ninja_available():
         return True
 
 
-def is_ipex_available():
+def is_ipex_available(min_version: str = ""):
     def get_major_and_minor_from_version(full_version):
         return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor)
 
@@ -880,6 +880,8 @@ def get_major_and_minor_from_version(full_version):
             f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again."
        )
        return False
+    if min_version:
+        return version.parse(_ipex_version) >= version.parse(min_version)
    return True
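The extension is backward compatible: calling is_ipex_available() with no argument behaves as before, while passing min_version additionally requires at least that IPEX release. Below is a minimal, self-contained sketch of the new gate; the _ipex_available flag and _ipex_version string are hypothetical stand-ins for the module-level globals detected at import time, and the real function also returns False first when the IPEX and PyTorch major.minor versions do not match.

from packaging import version

# Hypothetical stand-ins for the module-level globals that the real
# module populates at import time.
_ipex_available = True
_ipex_version = "2.5.0"

def is_ipex_available(min_version: str = "") -> bool:
    # No IPEX installed: never available.
    if not _ipex_available:
        return False
    # With a min_version, require at least that IPEX release.
    if min_version:
        return version.parse(_ipex_version) >= version.parse(min_version)
    # Without one, any installed IPEX counts as available.
    return True

print(is_ipex_available())       # True
print(is_ipex_available("2.5"))  # True: 2.5.0 >= 2.5
print(is_ipex_available("2.6"))  # False: 2.5.0 < 2.6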
tests/generation/test_utils.py (9 changes: 6 additions & 3 deletions)

@@ -24,6 +24,7 @@
 
 import numpy as np
 import pytest
+from packaging import version
 from parameterized import parameterized
 
 from transformers import AutoConfig, is_torch_available, pipeline, set_seed
@@ -44,6 +45,7 @@
     slow,
     torch_device,
 )
+from transformers.utils import is_ipex_available
 
 from ..test_modeling_common import floats_tensor, ids_tensor
 from .test_framework_agnostic import GenerationIntegrationTestsMixin
@@ -675,10 +677,11 @@ def test_beam_search_generate_dict_outputs_use_cache(self):
     @require_torch_multi_accelerator
     @pytest.mark.generate
     def test_model_parallel_beam_search(self):
-        for model_class in self.all_generative_model_classes:
-            if "xpu" in torch_device:
-                return unittest.skip(reason="device_map='auto' does not work with XPU devices")
+        if "xpu" in torch_device:
+            if not (is_ipex_available("2.5") or version.parse(torch.__version__) >= version.parse("2.6")):
+                self.skipTest(reason="device_map='auto' does not work with XPU devices")
 
+        for model_class in self.all_generative_model_classes:
             if model_class._no_split_modules is None:
                 continue
 
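The test previously skipped unconditionally on XPU, and did so incorrectly: returning unittest.skip(reason=...) from inside a test method just returns a decorator object, so the test ended early and passed rather than being recorded as skipped. The rewrite hoists the device check out of the per-model loop, calls self.skipTest() so the skip actually registers, and only skips when neither IPEX >= 2.5 nor PyTorch >= 2.6 is present. A minimal sketch of that condition, assuming torch is installed (the helper name is hypothetical):

import torch
from packaging import version
from transformers.utils import is_ipex_available

def xpu_supports_device_map_auto() -> bool:
    # device_map="auto" is expected to work on XPU once either a recent
    # enough IPEX (>= 2.5) or a recent enough PyTorch (>= 2.6) is present.
    return is_ipex_available("2.5") or version.parse(torch.__version__) >= version.parse("2.6")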