add UNSLOTH_ALLOW_CPU=1 path for CPU-only CI / source-inspection tests #5429

Merged
danielhanchen merged 1 commit into main from fix/allow-cpu-ci on May 15, 2026

@danielhanchen
Member

Summary

Adds an opt-in UNSLOTH_ALLOW_CPU=1 env var so import unsloth.trainer succeeds on hosts without a CUDA/XPU/HIP accelerator. Pairs with unsloth-zoo PR #646, which runs source-inspection drift detectors on CPU-only CI runners and needs the import path to land.

The env var is read exactly once per process via @functools.cache on get_device_type(). Production hosts (with a real accelerator) pay zero runtime cost: the existing fast path runs unchanged.
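The read-once behaviour can be sketched as follows. This is a simplified illustration, not the code in unsloth/device_type.py: `_accelerator_available()` is a hypothetical stand-in for the torch-based probe, and the real function also distinguishes XPU and HIP.

```python
import functools
import os


def _accelerator_available() -> bool:
    # Hypothetical stand-in for the real probe; always False here to
    # imitate a CPU-only CI runner.
    return False


@functools.cache
def get_device_type() -> str:
    if _accelerator_available():
        return "cuda"  # fast path on real hardware, env var never consulted
    # Opt-in CPU fallback: evaluated on the first call, then memoized.
    if os.environ.get("UNSLOTH_ALLOW_CPU", "0") == "1":
        return "cuda"  # sentinel value, not a real CUDA device
    raise NotImplementedError("No supported accelerator found")


os.environ["UNSLOTH_ALLOW_CPU"] = "1"
first = get_device_type()

# Changing the variable afterwards has no effect: the result is cached.
os.environ["UNSLOTH_ALLOW_CPU"] = "0"
second = get_device_type()
```

Note that `functools.cache` memoizes return values only, so a call that raises is not cached and would re-read the environment on the next attempt.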

Changes

| File | Change |
| --- | --- |
| `unsloth/device_type.py:65-78` | Gate both `raise NotImplementedError` sites on `os.environ.get("UNSLOTH_ALLOW_CPU", "0") == "1"`; return `"cuda"` as the CPU sentinel. |
| `unsloth/_gpu_init.py:212` | `if DEVICE_TYPE == "cuda":` -> `if DEVICE_TYPE == "cuda" and torch.cuda.is_available():` so the bf16-support probe doesn't fault on CPU. |
| `unsloth/_gpu_init.py:247` | Same guard for the `libcuda_dirs()`/`bnb.functional.lib.*` block. |
| `unsloth/_gpu_init.py:359` | Gate `_patch_trl_trainer()` on `UNSLOTH_ALLOW_CPU != "1"`. Prevents `trl.{X}Trainer.__init__` from being rebound to `_backwards_compatible_trainer`, which corrupts `inspect.getsource(SFTTrainer.__init__)` for downstream drift detectors. |
| `unsloth/models/_utils.py:1196` | Same `torch.cuda.is_available()` guard for `get_device_capability` at import. |
| `unsloth/models/rl.py:PatchFastRL` | Early-return under `UNSLOTH_ALLOW_CPU=1`. Without this, `patch_trl_rl_trainers()` rebinds `trl.SFTTrainer` to the compiled `UnslothSFTTrainer` class and zoo's drift detectors trip on the wrapper source. |

Test plan

  • UNSLOTH_ALLOW_CPU=1 python -c "import unsloth.trainer" on a CPU-only venv -> succeeds.
  • trl.SFTTrainer.__init__.__qualname__ stays SFTTrainer.__init__ (not UnslothSFTTrainer.__init__) under the env var.
  • inspect.getsource(trl.trainer.sft_trainer.SFTTrainer) still contains self._signature_columns (drift detector anchor).
  • Without the env var on a CUDA host, TRL is patched normally (UnslothSFTTrainer.__init__). No regression.
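The `__qualname__` check in the test plan can be illustrated with a small stand-in. The `SFTTrainer` class and `patch_trainer` helper below are hypothetical and much simpler than TRL's class or Unsloth's patching code; the sketch only shows how rebinding `__init__` changes what introspection reports, and how an opt-out gate leaves the class untouched.

```python
class SFTTrainer:
    """Stand-in for trl.SFTTrainer (hypothetical, not the real class)."""

    def __init__(self, model=None):
        self.model = model


def patch_trainer(cls, allow_cpu: bool) -> None:
    """Rebind cls.__init__ to a wrapper unless the CPU-CI gate is set."""
    if allow_cpu:
        return  # leave the upstream class pristine for source inspection

    original_init = cls.__init__

    def _backwards_compatible_init(self, *args, **kwargs):
        original_init(self, *args, **kwargs)

    cls.__init__ = _backwards_compatible_init


# Gate set: nothing is rebound, so introspection still sees the original.
patch_trainer(SFTTrainer, allow_cpu=True)
gated_qualname = SFTTrainer.__init__.__qualname__

# Gate unset: __init__ now points at the wrapper function.
patch_trainer(SFTTrainer, allow_cpu=False)
patched_qualname = SFTTrainer.__init__.__qualname__
```

After the ungated call, `inspect.getsource(SFTTrainer.__init__)` would likewise return the wrapper's source rather than the original method's, which is the failure mode the drift detectors hit.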

Why "cuda" instead of a real "cpu" value

The sentinel "cuda" only needs to let import unsloth.trainer and downstream module imports succeed. Downstream branches that would actually fault (get_device_capability, libcuda_dirs) are now gated by torch.cuda.is_available(). Introducing a real "cpu" device type would touch a lot of code; not the goal here.
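A minimal sketch of the guard pattern, assuming a bf16 probe of roughly this shape. The `major >= 8` threshold is an assumption made for illustration, and the `try`/`except ImportError` exists only so the snippet also runs where torch is not installed.

```python
try:
    import torch
    cuda_usable = torch.cuda.is_available()
except ImportError:
    # torch missing: treat as "no usable CUDA" for this sketch.
    torch = None
    cuda_usable = False

# May be the CPU-CI sentinel rather than a real CUDA device.
DEVICE_TYPE = "cuda"

if DEVICE_TYPE == "cuda" and cuda_usable:
    # Only reached when a real CUDA device is present.
    major, _minor = torch.cuda.get_device_capability()
    SUPPORTS_BFLOAT16 = major >= 8  # assumption: compute capability 8.x or newer
else:
    # CPU-only host running under the sentinel: skip the hardware probe.
    SUPPORTS_BFLOAT16 = False
```

The extra `torch.cuda.is_available()` condition is what keeps the sentinel value from reaching CUDA calls that would raise on a host without a GPU.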

Companion PR

  • unsloth-zoo PR Nightly #646 sets this env var in its tests/conftest.py before importing unsloth, then the matrix can run source-inspection drift detectors against the pristine upstream TRL classes.
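As an illustration of the ordering requirement, a conftest.py along these lines would set the variable before anything imports unsloth. This is a sketch of the idea, not the contents of the companion PR's file.

```python
# tests/conftest.py (illustrative)
import os

# Must run before the first `import unsloth`: the device type is resolved
# and cached during import, so setting the variable later has no effect.
os.environ["UNSLOTH_ALLOW_CPU"] = "1"

# The import is left commented out so the sketch runs without unsloth installed.
# import unsloth  # noqa: E402
```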

Lets `import unsloth.trainer` succeed on hosts without a CUDA/XPU/HIP
accelerator (typical of zoo's source-inspection test matrix). The env
var is read exactly once per process via @functools.cache on
`get_device_type()`, so production hosts pay no runtime cost.

Four edits beyond the device_type fallback:

* `_gpu_init.py:212/247` -- the bf16 + libcuda/bnb setup blocks call
  `torch.cuda.get_device_capability()` and `libcuda_dirs()`/`bnb.functional.lib.*`
  unconditionally when DEVICE_TYPE == "cuda". Guard with
  `and torch.cuda.is_available()` so the new CPU-CI sentinel doesn't
  fault those.
* `_gpu_init.py:353` -- gate `_patch_trl_trainer()` (the
  `_backwards_compatible_trainer.__init__` wrapper). Under
  UNSLOTH_ALLOW_CPU we want pristine upstream TRL classes for
  downstream `inspect.getsource(SFTTrainer)` drift detectors.
* `models/_utils.py:1196` -- same `and torch.cuda.is_available()` guard
  for `get_device_capability()` at import time.
* `models/rl.py:PatchFastRL` -- early-return under UNSLOTH_ALLOW_CPU=1
  so the heavier `patch_trl_rl_trainers()` (which replaces
  `trl.SFTTrainer` with the compiled `UnslothSFTTrainer` class)
  doesn't fire either. Without this gate the drift detectors that
  do `inspect.getsource(SFTTrainer)` see the wrapper source and
  fail spuriously.

Local sanity: `UNSLOTH_ALLOW_CPU=1 python -c "import unsloth.trainer"`
succeeds on a CPU-only venv, `trl.SFTTrainer.__init__.__qualname__`
stays `SFTTrainer.__init__` (not `UnslothSFTTrainer.__init__`), and
`inspect.getsource(SFTTrainer)` still contains `self._signature_columns`.
Without the env var on a CUDA host, TRL is still patched normally
(verified `UnslothSFTTrainer.__init__`).
@danielhanchen danielhanchen merged commit cb15a7a into main May 15, 2026
25 of 32 checks passed
@danielhanchen danielhanchen deleted the fix/allow-cpu-ci branch May 15, 2026 03:27

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 6ef7cbd61c

Comment thread unsloth/_gpu_init.py

  # Torch 2.4 has including_emulation
- if DEVICE_TYPE == "cuda":
+ if DEVICE_TYPE == "cuda" and torch.cuda.is_available():

P1: Gate the zoo device probe before CPU fallback

On CPU-only CI with UNSLOTH_ALLOW_CPU=1, this guard is reached only after _gpu_init has already imported DEVICE_TYPE from unsloth_zoo.device_type (lines 125-130). With any installed unsloth_zoo version that satisfies the current dependency but does not yet include the companion fallback, that import still raises NotImplementedError before the new local unsloth/device_type.py fallback or this torch.cuda.is_available() guard can run, so the advertised import unsloth.trainer path remains broken unless users happen to upgrade zoo out-of-band.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces support for CPU-only CI environments by implementing an UNSLOTH_ALLOW_CPU environment variable, which allows the library to bypass GPU hardware requirements and skip TRL trainer patching during source-inspection tests. The changes include adding torch.cuda.is_available() guards to prevent hardware probing on non-GPU systems. Review feedback points out a missing guard for a CUDA capability check in unsloth/models/_utils.py that could still trigger misleading warnings on CPU hosts and suggests caching the environment variable lookup to improve code maintainability.

Comment thread unsloth/models/_utils.py
HAS_FLASH_ATTENTION_SOFTCAPPING = False

if DEVICE_TYPE == "cuda":
if DEVICE_TYPE == "cuda" and torch.cuda.is_available():
Contributor

high

While this guard correctly prevents hardware probes on CPU hosts, a similar call to torch.cuda.get_device_capability() at line 1312 is currently unguarded. Since DEVICE_TYPE is now spoofed as "cuda" on CPU hosts when UNSLOTH_ALLOW_CPU=1, that site will attempt to call into the CUDA API and raise a RuntimeError during import. Although it is wrapped in a broad try...except Exception, it will still trigger a misleading "Switching to PyTorch attention since your Xformers is broken" warning if logging is enabled. Please apply the same torch.cuda.is_available() guard to line 1312.
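The effect described here can be reproduced with a small stand-in. `probe_capability` and `select_attention` are hypothetical names, not functions in unsloth; the sketch only shows how a broad `except Exception` turns a failed hardware probe into an unrelated-looking warning, and how a guard avoids it.

```python
import warnings


def probe_capability():
    # Stand-in for torch.cuda.get_device_capability() on a host with no GPU.
    raise RuntimeError("no CUDA device")


def select_attention(cuda_usable: bool, guarded: bool) -> str:
    if guarded and not cuda_usable:
        return "pytorch"  # skip the probe quietly
    try:
        probe_capability()
        return "xformers"
    except Exception:
        # The broad handler hides the real cause behind a generic message.
        warnings.warn("Switching to PyTorch attention since your Xformers is broken")
        return "pytorch"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    unguarded_backend = select_attention(cuda_usable=False, guarded=False)
    unguarded_warnings = len(caught)
    guarded_backend = select_attention(cuda_usable=False, guarded=True)
    guarded_warnings = len(caught) - unguarded_warnings
```

Both paths end up on the same backend; only the unguarded one emits the misleading warning.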

Comment thread unsloth/device_type.py
if not torch.accelerator.is_available():
# Test-only CPU fallback. The env var is read exactly once per
# process because get_device_type is @functools.cache'd.
if os.environ.get("UNSLOTH_ALLOW_CPU", "0") == "1":
Contributor

medium

The environment variable UNSLOTH_ALLOW_CPU is read multiple times within this function and across other modules (_gpu_init.py, rl.py). Since get_device_type is cached, consider capturing this value into a constant or a local variable at the start of the function to improve maintainability and avoid repeated string lookups.

josephGuo pushed a commit to josephGuo/unsloth that referenced this pull request May 16, 2026
…othai#5473)

unslothai#5429 (cb15a7a) tightened three production-path branches to
DEVICE_TYPE == "cuda" and torch.cuda.is_available() and added a
new else: SUPPORTS_BFLOAT16 = False arm to let `import unsloth.trainer`
survive on a CPU-only CI host. We already ship the package on Intel
XPU / AMD HIP / NVIDIA CUDA and don't want any extra branching in
those hot paths.

Move the entire CPU-CI handling to one place -- the top of
unsloth/device_type.get_device_type() -- so the UNSLOTH_ALLOW_CPU=1
sentinel short-circuits detection and returns "cuda" before any
torch probe runs. Every downstream DEVICE_TYPE == "cuda" branch
then behaves identically to a real CUDA host, with no additional
checks. The two existing duplicate UNSLOTH_ALLOW_CPU returns
later in the function are dropped (the new top-of-function check
covers both).

Revert the three call-site changes:
- unsloth/_gpu_init.py:212  -> back to `if DEVICE_TYPE == "cuda":`
- unsloth/_gpu_init.py:247  -> back to `if DEVICE_TYPE == "cuda":`
- unsloth/models/_utils.py:1207 -> back to `if DEVICE_TYPE == "cuda":`
- unsloth/_gpu_init.py: drop the new `else: SUPPORTS_BFLOAT16 = False`
  branch (dead under the top-of-function short-circuit).

Keep the two env-var gates that are needed for zoo's drift detectors
to inspect pristine TRL source (no behavioural change on production
hosts that never set UNSLOTH_ALLOW_CPU):
- unsloth/_gpu_init.py: `if env != "1": _patch_trl_trainer()`
- unsloth/models/rl.py:PatchFastRL: `if env == "1": return`

Verified:
- CUDA_VISIBLE_DEVICES=5 python -c "import unsloth.trainer" produces
  UnslothSFTTrainer.__init__ (TRL still patched on real hosts).
- UNSLOTH_ALLOW_CPU=1 + aggressive cuda spoof import succeeds and
  trl.SFTTrainer.__init__.__qualname__ stays SFTTrainer.__init__.
- pytest tests/_zoo_compiler_cache_shim.py -> 5 passed, 1 skipped.
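
The top-of-function short-circuit described in this follow-up commit can be sketched as below. `_probe_accelerator` is a hypothetical stand-in for the torch-based detection, and the call log exists only to show that no probe runs when the sentinel is active.

```python
import functools
import os

probe_calls = []  # records every time the (stand-in) hardware probe runs


def _probe_accelerator() -> str:
    # Hypothetical stand-in for the real detection; fails like a CPU-only host.
    probe_calls.append("probe")
    raise NotImplementedError("no accelerator")


@functools.cache
def get_device_type() -> str:
    # Single CPU-CI check, placed before any probing.
    if os.environ.get("UNSLOTH_ALLOW_CPU", "0") == "1":
        return "cuda"
    return _probe_accelerator()


os.environ["UNSLOTH_ALLOW_CPU"] = "1"
device_type = get_device_type()
```

With the check at the top, hosts that never set the variable go straight to the existing detection, and the call sites need no extra conditions.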