Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions unsloth/_gpu_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@
del patch_peft_weight_converter_compatibility

# Torch 2.4 has including_emulation
if DEVICE_TYPE == "cuda":
if DEVICE_TYPE == "cuda" and torch.cuda.is_available():
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Gate the zoo device probe before CPU fallback

On CPU-only CI with `UNSLOTH_ALLOW_CPU=1`, this guard is reached only after `_gpu_init` has already imported `DEVICE_TYPE` from `unsloth_zoo.device_type` (lines 125-130). With any installed `unsloth_zoo` version that satisfies the current dependency but does not yet include the companion fallback, that import still raises `NotImplementedError` before the new local `unsloth/device_type.py` fallback or this `torch.cuda.is_available()` guard can run. As a result, the advertised `import unsloth.trainer` path remains broken unless users happen to upgrade zoo out-of-band.

Useful? React with 👍 / 👎.

major_version, minor_version = torch.cuda.get_device_capability()
SUPPORTS_BFLOAT16 = major_version >= 8

Expand All @@ -233,12 +233,18 @@ def is_bf16_supported():
# torch.xpu.is_bf16_supported() does not have including_emulation
# set SUPPORTS_BFLOAT16 as torch.xpu.is_bf16_supported()
SUPPORTS_BFLOAT16 = torch.xpu.is_bf16_supported()
else:
# CPU-only CI under UNSLOTH_ALLOW_CPU=1. We can't probe device
# capability, so assume no bf16 -- training won't run on this host
# anyway, this branch only exists to let `import unsloth.trainer`
# succeed for source-inspection tests.
SUPPORTS_BFLOAT16 = False

# For Gradio HF Spaces?
# if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ:
import triton

if DEVICE_TYPE == "cuda":
if DEVICE_TYPE == "cuda" and torch.cuda.is_available():
libcuda_dirs = lambda: None
if Version(triton.__version__) >= Version("3.0.0"):
try:
Expand Down Expand Up @@ -349,5 +355,10 @@ def is_bf16_supported():
launch_openenv,
)

# Patch TRL trainers for backwards compatibility
_patch_trl_trainer()
# Patch TRL trainers for backwards compatibility.
# Skipped under UNSLOTH_ALLOW_CPU=1 (CPU-only CI) because rebinding
# trl.SFTTrainer.__init__ to a generic wrapper changes
# inspect.getsource(SFTTrainer.__init__) and corrupts downstream
# drift detectors that anchor on the pristine upstream source.
if os.environ.get("UNSLOTH_ALLOW_CPU", "0") != "1":
_patch_trl_trainer()
6 changes: 6 additions & 0 deletions unsloth/device_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def get_device_type():
# Check torch.accelerator
if hasattr(torch, "accelerator"):
if not torch.accelerator.is_available():
# Test-only CPU fallback. The env var is read exactly once per
# process because get_device_type is @functools.cache'd.
if os.environ.get("UNSLOTH_ALLOW_CPU", "0") == "1":
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The environment variable `UNSLOTH_ALLOW_CPU` is read multiple times within this function and across other modules (`_gpu_init.py`, `rl.py`). Since `get_device_type` is cached, consider capturing this value into a constant or a local variable at the start of the function to improve maintainability and avoid repeated string lookups.

return "cuda"
raise NotImplementedError(
"Unsloth cannot find any torch accelerator? You need a GPU."
)
Expand All @@ -73,6 +77,8 @@ def get_device_type():
f"But `torch.accelerator.current_accelerator()` works with it being = `{accelerator}`\n"
f"Please reinstall torch - it's most likely broken :("
)
if os.environ.get("UNSLOTH_ALLOW_CPU", "0") == "1":
return "cuda"
raise NotImplementedError(
"Unsloth currently only works on NVIDIA, AMD and Intel GPUs."
)
Expand Down
2 changes: 1 addition & 1 deletion unsloth/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,7 @@ def _is_openai_available():
HAS_FLASH_ATTENTION = False
HAS_FLASH_ATTENTION_SOFTCAPPING = False

if DEVICE_TYPE == "cuda":
if DEVICE_TYPE == "cuda" and torch.cuda.is_available():
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

While this guard correctly prevents hardware probes on CPU hosts, a similar call to `torch.cuda.get_device_capability()` at line 1312 is currently unguarded. Since `DEVICE_TYPE` is now spoofed as `"cuda"` on CPU hosts when `UNSLOTH_ALLOW_CPU=1`, that site will attempt to call into the CUDA API and raise a `RuntimeError` during import. Although it is wrapped in a broad `try ... except Exception`, it will still trigger a misleading "Switching to PyTorch attention since your Xformers is broken" warning if logging is enabled. Please apply the same `torch.cuda.is_available()` guard to line 1312.

major_version, minor_version = torch.cuda.get_device_capability()
torch.cuda.get_device_capability = functools.cache(torch.cuda.get_device_capability)

Expand Down
5 changes: 5 additions & 0 deletions unsloth/models/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2270,6 +2270,11 @@ def patch_trl_vllm_generation():
def PatchFastRL(algorithm = None, FastLanguageModel = None):
if FastLanguageModel is not None:
PatchRL(FastLanguageModel)
# Under UNSLOTH_ALLOW_CPU=1 (CPU-only CI), skip TRL trainer rewriting so
# downstream `inspect.getsource(trl.SFTTrainer)` drift detectors see the
# pristine upstream class, not the compiled Unsloth* wrappers.
if os.environ.get("UNSLOTH_ALLOW_CPU", "0") == "1":
return
# Install the disable_gradient_checkpointing noop BEFORE
# patch_trl_rl_trainers. patch_trl_rl_trainers imports extra trl.* trainer
# submodules while generating the compiled cache; any new trl.* modules
Expand Down
Loading