fix(gemma): Replace hardcoded CUDA calls in GemmaFixedRotaryEmbedding for XPU support #4928
cheehook wants to merge 1 commit into
Conversation
Code Review
This pull request refactors device management in unsloth/models/gemma.py by replacing direct torch.cuda.current_device() calls with a more generalized get_current_device() function. It also explicitly sets the device type to DEVICE_TYPE_TORCH when creating torch.device objects and abstracts torch.cuda.empty_cache() into a clean_gpu_cache() utility. These changes likely aim to improve device abstraction and flexibility. There are no review comments to provide feedback on.
@cheehook are the changes in …
Replace device-specific CUDA calls in GemmaFixedRotaryEmbedding with device-agnostic helpers to fix crashes when loading Gemma v1 models on Intel XPU systems.

Changes:
- torch.cuda.current_device() → get_current_device()
- torch.cuda.empty_cache() → clean_gpu_cache()
- torch.device(device) → torch.device(DEVICE_TYPE_TORCH, device)

Fixes "Torch not compiled with CUDA enabled" error on non-CUDA platforms.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
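The three substitutions in the commit message can be illustrated with a minimal, PyTorch-free sketch. The helper names mirror those in the PR, but their bodies here are assumptions made for illustration; the real helpers in llama.py may be implemented differently. fake_torch is a hypothetical stand-in for the torch module on an XPU-only build, and the optional torch_like parameter exists only so the sketch runs without PyTorch installed.

```python
from types import SimpleNamespace


def _cuda_not_compiled(*args, **kwargs):
    # Simulates the failure quoted in the commit message.
    raise AssertionError("Torch not compiled with CUDA enabled")


# Hypothetical stand-in for the torch module on an XPU-only build.
fake_torch = SimpleNamespace(
    cuda=SimpleNamespace(
        current_device=_cuda_not_compiled,
        empty_cache=_cuda_not_compiled,
    ),
    xpu=SimpleNamespace(
        current_device=lambda: 0,
        empty_cache=lambda: None,
    ),
)

# Assumed to be resolved once at import time; fixed to "xpu" for illustration.
DEVICE_TYPE_TORCH = "xpu"


def get_current_device(torch_like=fake_torch):
    """Return the active device index on the selected backend."""
    return getattr(torch_like, DEVICE_TYPE_TORCH).current_device()


def clean_gpu_cache(torch_like=fake_torch):
    """Release cached memory on the selected backend."""
    getattr(torch_like, DEVICE_TYPE_TORCH).empty_cache()


# Hardcoded CUDA call: fails on the simulated XPU-only build.
try:
    fake_torch.cuda.current_device()
    hardcoded_error = None
except AssertionError as err:
    hardcoded_error = str(err)

# Device-agnostic path: dispatches to the XPU backend instead.
device_index = get_current_device()
clean_gpu_cache()
# With real PyTorch this would be torch.device(DEVICE_TYPE_TORCH, device_index).
device_name = f"{DEVICE_TYPE_TORCH}:{device_index}"

print(hardcoded_error)
print(device_name)
```

Resolving the device type once and dispatching through it keeps backend names out of the call sites, which is why the same three substitutions apply unchanged wherever the hardcoded calls appear.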
c5b255a to 7131356
Thanks for catching that. I've removed the unrelated test file changes from the pre-commit-ci bot and squashed the commits. The PR now contains only the Gemma XPU fix in a single clean commit.
Description
Loading Gemma v1 models on Intel XPU systems crashes at GemmaFixedRotaryEmbedding.__init__ in unsloth/models/gemma.py with the error "Torch not compiled with CUDA enabled". This does not affect Gemma 2/3, which use transformers' built-in rotary embeddings and never hit this code path.
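To check whether an environment falls into the affected category (XPU available, CUDA not), a small probe along these lines may help. It is a sketch, not part of unsloth: the function name available_backends is illustrative, and torch.xpu is looked up defensively because it is only present in newer PyTorch releases.

```python
def available_backends():
    """Report which accelerator backends this Python environment can use."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed, so no backend is usable.
        return {"cuda": False, "xpu": False}

    # Returns False (rather than raising) on builds without CUDA support.
    cuda_ok = bool(torch.cuda.is_available())

    # torch.xpu may be missing on older PyTorch versions.
    xpu = getattr(torch, "xpu", None)
    xpu_ok = bool(xpu is not None and xpu.is_available())

    return {"cuda": cuda_ok, "xpu": xpu_ok}


if __name__ == "__main__":
    backends = available_backends()
    print(backends)
    if backends["xpu"] and not backends["cuda"]:
        print("XPU-only environment: hardcoded torch.cuda calls would fail here.")
```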
Root Cause
GemmaFixedRotaryEmbedding and FastGemmaModel.post_patch contain several hardcoded CUDA calls.
Fix
Replace all hardcoded calls with the device-agnostic helpers that are already defined in llama.py and already imported into gemma.py through the wildcard import (from .llama import *). The fix follows the same pattern as llama.py:
- torch.device(device) --> torch.device(DEVICE_TYPE_TORCH, device)
- torch.cuda.current_device() --> get_current_device()
- torch.cuda.empty_cache() --> clean_gpu_cache()
Environment