fix(gemma): Replace hardcoded CUDA calls in GemmaFixedRotaryEmbedding for XPU support #4928

Open
cheehook wants to merge 1 commit into unslothai:main from cheehook:fix-gemma-xpu-error

@cheehook cheehook commented Apr 9, 2026

Description

Loading Gemma v1 models on Intel XPU systems crashes in GemmaFixedRotaryEmbedding.__init__ in unsloth/models/gemma.py.

This does not affect Gemma 2/3, which use transformers' built-in rotary embeddings and never hit this code path.

(unsloth) sdp@emr816608-vm01:~/ch/frameworks.ai.trainingframework.recipes/sandbox/dpo$ ZE_AFFINITY_MASK=1 python dpo_finetuning.py --config configs/gemma3-4b-it_unsloth.yaml 
Skipping import of cpp extensions due to incompatible torch version 2.11.0+xpu for torchao version 0.16.0             Please see https://github.com/pytorch/ao/issues/2919 for more info
Target device: xpu
XPU available: True
XPU device count: 1
XPU device name: Intel(R) Arc(TM) Pro B60 Graphics
/home/sdp/ch/frameworks.ai.trainingframework.recipes/sandbox/dpo/dpo_finetuning.py:795: UserWarning: WARNING: Unsloth should be imported before [transformers] to ensure all optimizations are applied. Your code may run slower or encounter memory issues without these optimizations.

Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth import PatchDPOTrainer, is_bfloat16_supported
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
[Unsloth] backend enabled
==((====))==  Unsloth 2026.3.17: Fast Gemma patching. Transformers: 4.57.6. vLLM: 0.18.1rc1.dev189+gaee4c1468.xpu.
   \\   /|    Intel(R) Arc(TM) Pro B60 Graphics. Num GPUs = 1. Max memory: 23.906 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.11.0+xpu. Intel Toolkit: 20250302. Triton: 3.7.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: https://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Traceback (most recent call last):
  File "/home/sdp/ch/frameworks.ai.trainingframework.recipes/sandbox/dpo/dpo_finetuning.py", line 1090, in <module>
    main(args.config)
  File "/home/sdp/ch/frameworks.ai.trainingframework.recipes/sandbox/dpo/dpo_finetuning.py", line 802, in main
    model, tokenizer = setup_unsloth_model(cfg)
                       ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sdp/ch/frameworks.ai.trainingframework.recipes/sandbox/dpo/dpo_finetuning.py", line 587, in setup_unsloth_model
    model, tokenizer = FastLanguageModel.from_pretrained(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sdp/ch/unsloth/unsloth/models/loader.py", line 711, in from_pretrained
    model, tokenizer = dispatch_model.from_pretrained(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sdp/ch/unsloth/unsloth/models/llama.py", line 2430, in from_pretrained
    model = AutoModelForCausalLM.from_pretrained(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sdp/miniforge3/envs/unsloth-2026.3.4/lib/python3.12/site-packages/transformers/models/auto/auto_factory.py", line 604, in from_pretrained
    return model_class.from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sdp/miniforge3/envs/unsloth-2026.3.4/lib/python3.12/site-packages/transformers/modeling_utils.py", line 277, in _wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/sdp/miniforge3/envs/unsloth-2026.3.4/lib/python3.12/site-packages/transformers/modeling_utils.py", line 4971, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sdp/miniforge3/envs/unsloth-2026.3.4/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py", line 429, in __init__
    self.model = GemmaModel(config)
                 ^^^^^^^^^^^^^^^^^^
  File "/home/sdp/miniforge3/envs/unsloth-2026.3.4/lib/python3.12/site-packages/transformers/models/gemma/modeling_gemma.py", line 345, in __init__
    self.rotary_emb = GemmaRotaryEmbedding(config=config)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sdp/ch/unsloth/unsloth/models/gemma.py", line 309, in __init__
    1, device = torch.cuda.current_device(), dtype = torch.get_default_dtype()
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sdp/miniforge3/envs/unsloth-2026.3.4/lib/python3.12/site-packages/torch/cuda/__init__.py", line 1148, in current_device
    _lazy_init()
  File "/home/sdp/miniforge3/envs/unsloth-2026.3.4/lib/python3.12/site-packages/torch/cuda/__init__.py", line 471, in _lazy_init
    raise AssertionError("Torch not compiled with CUDA enabled")
AssertionError: Torch not compiled with CUDA enabled

Root Cause

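GemmaFixedRotaryEmbedding and FastGemmaModel.post_patch hardcode the CUDA backend: they call torch.cuda.current_device() and torch.cuda.empty_cache() directly, and build devices with torch.device(device) without naming a device type. On a PyTorch build without CUDA, torch.cuda.current_device() raises the AssertionError shown in the traceback above.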
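The failure can be checked outside of Unsloth with a small probe. This is only a sketch: the function name probe_cuda_current_device is made up for illustration, and the import guard is there so the script also runs where PyTorch is not installed.

```python
try:
    import torch
except ImportError:  # keeps the probe runnable where PyTorch is not installed
    torch = None


def probe_cuda_current_device() -> str:
    """Describe what torch.cuda.current_device() does on this machine.

    Hypothetical helper for illustration; it is not part of unsloth.
    """
    if torch is None:
        return "PyTorch is not installed"
    try:
        index = torch.cuda.current_device()
    except (AssertionError, RuntimeError) as exc:
        # A non-CUDA build (for example 2.11.0+xpu) raises AssertionError,
        # matching the traceback in this PR. RuntimeError is caught as well
        # for CUDA builds that cannot initialise a device.
        return f"failed: {type(exc).__name__}: {exc}"
    return f"ok: CUDA device index {index}"


if __name__ == "__main__":
    print(probe_cuda_current_device())
```

On an XPU-only install this is expected to report the same "Torch not compiled with CUDA enabled" assertion, which is why any unconditional torch.cuda call in a model constructor is fatal there.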

Fix

Replace the hardcoded calls with the device-agnostic helpers that are already defined in llama.py and already imported into gemma.py via from .llama import *. The fix follows the same pattern llama.py uses:

torch.device(device) --> torch.device(DEVICE_TYPE_TORCH, device)
torch.cuda.current_device() --> get_current_device()
torch.cuda.empty_cache() --> clean_gpu_cache()
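For readers unfamiliar with the pattern, the three substitutions can be sketched roughly as follows. This is not Unsloth's implementation: detect_device_type, current_device_index, empty_device_cache and device_string are hypothetical stand-ins for DEVICE_TYPE_TORCH, get_current_device() and clean_gpu_cache(), whose real definitions are not shown in this PR. The sketch assumes a PyTorch release that ships the torch.xpu module and falls back to CPU otherwise.

```python
try:
    import torch
except ImportError:  # lets the sketch run on machines without PyTorch
    torch = None


def detect_device_type() -> str:
    """Pick the backend once, instead of assuming CUDA everywhere."""
    if torch is None:
        return "cpu"
    if torch.cuda.is_available():  # returns False (no raise) on non-CUDA builds
        return "cuda"
    xpu = getattr(torch, "xpu", None)  # absent in older PyTorch releases
    if xpu is not None and xpu.is_available():
        return "xpu"
    return "cpu"


# Hypothetical stand-in for DEVICE_TYPE_TORCH.
DEVICE_TYPE = detect_device_type()


def current_device_index() -> int:
    """Hypothetical stand-in for get_current_device()."""
    if DEVICE_TYPE == "cuda":
        return torch.cuda.current_device()
    if DEVICE_TYPE == "xpu":
        return torch.xpu.current_device()
    return 0


def empty_device_cache() -> None:
    """Hypothetical stand-in for clean_gpu_cache()."""
    if DEVICE_TYPE == "cuda":
        torch.cuda.empty_cache()
    elif DEVICE_TYPE == "xpu":
        torch.xpu.empty_cache()
    # Nothing to release on CPU.


def device_string() -> str:
    """Name the device type explicitly, e.g. "xpu:0", rather than a bare index."""
    return f"{DEVICE_TYPE}:{current_device_index()}"


if __name__ == "__main__":
    print(device_string())
    empty_device_cache()
```

The point of the design is that the backend decision is made in one place, so model code never touches torch.cuda directly and the same constructor works on CUDA, XPU, or CPU.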

Environment

  • Intel Arc Pro B60 (XPU)
  • Unsloth 2026.3.17
  • Unsloth-zoo 2026.4.3
  • Torch 2.11.0+xpu
  • Triton-xpu 3.7.0

@cheehook cheehook requested a review from danielhanchen as a code owner April 9, 2026 03:41

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request refactors device management in unsloth/models/gemma.py by replacing direct torch.cuda.current_device() calls with a more generalized get_current_device() function. It also explicitly sets the device type to DEVICE_TYPE_TORCH when creating torch.device objects and abstracts torch.cuda.empty_cache() into a clean_gpu_cache() utility. These changes likely aim to improve device abstraction and flexibility. There are no review comments to provide feedback on.

@pre-commit-ci pre-commit-ci Bot requested a review from rolandtannous as a code owner April 9, 2026 12:20
@yao-matrix

@cheehook are the changes in tests/test_raw_text.py necessary and relevant to the purpose of this PR? I suggest removing the irrelevant changes.

Replace device-specific CUDA calls in GemmaFixedRotaryEmbedding with
device-agnostic helpers to fix crashes when loading Gemma v1 models on
Intel XPU systems.

Changes:
- torch.cuda.current_device() → get_current_device()
- torch.cuda.empty_cache() → clean_gpu_cache()
- torch.device(device) → torch.device(DEVICE_TYPE_TORCH, device)

Fixes "Torch not compiled with CUDA enabled" error on non-CUDA platforms.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
@cheehook cheehook force-pushed the fix-gemma-xpu-error branch 2 times, most recently from c5b255a to 7131356 Compare April 15, 2026 03:06
@cheehook

@cheehook are the changes in tests/test_raw_text.py necessary and relevant to the purpose of this PR? I suggest removing the irrelevant changes.

Thanks for catching that. I've removed the unrelated test file changes introduced by the pre-commit-ci bot and squashed the commits. The PR now contains only the Gemma XPU fix in a single clean commit.
