Skip to content

Make bitsandbytes optional on ROCm and add bf16 helper#4000

Closed
danielhanchen wants to merge 3346 commits into
mainfrom
fix/amd-optional-bnb
Closed

Make bitsandbytes optional on ROCm and add bf16 helper#4000
danielhanchen wants to merge 3346 commits into
mainfrom
fix/amd-optional-bnb

Conversation

@danielhanchen
Copy link
Copy Markdown
Member

Summary\n\nThis makes Unsloth more robust on environments where bitsandbytes is not available (notably ROCm), while keeping CUDA/NVIDIA behavior unchanged when bitsandbytes is installed.\n\nChanges:\n- Guard bitsandbytes imports in several modules so works without bitsandbytes.\n- Avoid crashes by building type-tuples only from available classes.\n- Add a stable helper on HIP/XPU for backwards notebook compatibility.\n- Guard the vLLM aimv2 patch when vLLM package metadata is missing (module present but no dist-info).\n- GRPO: align mask/coef lengths in the loss path when left-padding creates a length mismatch.\n- PEFT compatibility: drop kwarg when running with older .\n\n## Testing\n\n- on the touched modules.\n- Validated in a ROCm notebook-suite environment (no bitsandbytes installed) where Unsloth notebooks need to import and train successfully.

oKatanaaa and others added 30 commits December 11, 2025 03:21
* skip xpu fbgemm fp8

* Apply suggestion from @gemini-code-assist[bot]

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestion from @danielhanchen

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
* pipe kwargs through mistral model

* simplify / bugfix

* bugfix for train_on_completions_only

* wire up is_unsupported_model

* nits, edge cases
* Fix get_input_embeds call for VLMs

* patch input_require_grads instead

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup old patch

* cleanup old patch

* cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestion from @danielhanchen

* use logger instead of prints

* Move unsloth present set

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
* Update torchao save

* up

* up

* up

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestion from @danielhanchen

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
* Update _utils.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [FIX] [Transformers] VLM input embeds fix for gradients (#3715)

* Fix get_input_embeds call for VLMs

* patch input_require_grads instead

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* cleanup old patch

* cleanup old patch

* cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Apply suggestion from @danielhanchen

* use logger instead of prints

* Move unsloth present set

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>

* Update rope_embedding.py

* Fixes

* Update _utils.py

* Update import_fixes.py

* Update rl_replacements.py

* fix_openenv_no_vllm

* Fix

* Update __init__.py

* Update __init__.py

* Update __init__.py

* Update import_fixes.py

* Update import_fixes.py

* Update import_fixes.py

* logger

* Update __init__.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update __init__.py

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
danielhanchen and others added 20 commits February 3, 2026 02:48
When users pass `num_train_epochs=None` to GRPOConfig (relying on
max_steps to control training duration), Trainer.__init__ fails with:

  TypeError: '>' not supported between instances of 'NoneType' and 'int'

This happens because transformers.Trainer does `args.num_train_epochs > 0`
in its __init__ which fails when the value is None.

This fix converts None to 3.0 (the default) before Trainer initialization.
The actual training duration is still controlled by max_steps since it
takes precedence when both are set.

Example that now works:
```python
config = GRPOConfig(
    num_train_epochs=None,  # Previously caused TypeError
    max_steps=500,          # This controls actual duration
    ...
)
```
…#3971)

* Add TRL truncation regression and metadata loss fixes

Fix 1: TRL 0.24.0-0.25.1 right-truncation regression
- These versions pass max_length=self.max_prompt_length and truncation=True
  to the tokenizer, which right-truncates prompts and strips the assistant
  turn suffix
- Use regex to remove these kwargs from the generated code

Fix 3: Metadata loss for chat_template_kwargs
- TRL 0.24.0+ extracts prompts = [x["prompt"] for x in inputs], losing metadata
  like reasoning_effort
- Inject code to store per-sample chat_template_kwargs on self before extraction
- Preserve these kwargs in prompts_text generation for all TRL versions

Tested with TRL versions 0.22.2, 0.23.1, 0.24.0, 0.25.1, 0.26.2, and 0.27.1.

* Update Fix 1 comment with detailed TRL version behavior explanation

Expand the comment for the TRL 0.24.0-0.25.1 truncation regression fix
to clarify what each TRL version does:

- TRL 0.22.2-0.23.1: Uses truncate_with_protected_tokens() for smart
  truncation that preserves rightmost tokens and protects special tokens
- TRL 0.24.0-0.25.1: Removed smart truncation, passes kwargs directly
  to tokenizer (max_length, truncation=True, add_special_tokens=False)
- TRL 0.26.2+: Removed these kwargs entirely

The fix removes these problematic kwargs so 0.24.0-0.25.1 behaves like
0.26.2+ (no tokenizer-level truncation).

---------

Co-authored-by: danielhanchen <danielhanchen@users.noreply.github.com>
vLLM's distributed module (device_communicators) crashes with std::bad_alloc
when imported on SM100 GPUs (B200/B100/Blackwell) with torch < 2.9.0.

This adds an early check that runs before vLLM is imported, providing a
helpful error message instead of a cryptic C++ exception.

The check:
1. Detects if vLLM is installed
2. Checks if torch version is < 2.9.0
3. Checks if any GPU is SM100 (Blackwell)
4. If all conditions met, raises RuntimeError with clear upgrade instructions
…h versions (#3978)

* Fix torchvision compatibility check for source builds and future torch versions

The torchvision version check raised a hard ImportError for custom/source-built
PyTorch installations (e.g. AMD ROCm from source with +git* suffixes), even when
the actual build was functional. This also silently skipped any torch version
not already in the hardcoded table, giving no warning at all for future releases.

Changes:
- Detect custom/source builds by checking the raw version string's local
  identifier against known standard prefixes (cu, rocm, cpu, xpu). Our custom
  Version() strips local identifiers via regex, so detection must happen on the
  raw string before parsing.
- Downgrade to a warning (instead of ImportError) for custom/source builds,
  since their version numbers may not follow standard PyPI release pairings.
- Add formula-based inference for future torch versions not yet in the table.
  The torch->torchvision minor version formula (torch 2.x -> tv 0.(x+15)) has
  held for every release from torch 2.0 through 2.9. For formula-predicted
  versions, mismatches produce a warning rather than a hard error.
- Add UNSLOTH_SKIP_TORCHVISION_CHECK=1 env var to skip the check entirely.
- Wrap importlib_version and Version calls in try/except so broken metadata
  never crashes the import.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address review: stricter regex, case insensitivity, pre-release detection

Fixes three edge cases found during review:

1. Regex precision: cu/xpu now require a trailing digit (cu\d, xpu\d) to
   avoid false negatives on suffixes like "+custom_build" that happen to
   start with "cu". cpu/xpu match as exact strings only.

2. Case insensitivity: added re.IGNORECASE so "+ROCM6.3" and "+CPU" are
   correctly recognized as standard builds rather than custom ones.

3. Pre-release detection: nightly/dev/alpha/beta/rc builds with standard
   CUDA/ROCm suffixes (e.g. "2.7.0.dev20250301+cu124") now produce a
   warning instead of a hard ImportError. These builds commonly have
   version mismatches that are expected during development.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Address PR review comments: fullmatch, env var casing, torchvision pre-release

1. Switch re.match to re.fullmatch for the custom build regex so the
   entire local identifier must match. Fixes false negatives where
   suffixes like +cu124_custom were misclassified as standard because
   re.match only checked the start of the string.

2. Use .lower() for the UNSLOTH_SKIP_TORCHVISION_CHECK env var so
   any casing of "true" / "TRUE" / etc. is accepted.

3. Check torchvision_version_raw for pre-release tags in addition to
   torch_version_raw, so a stable torch paired with a nightly
   torchvision (e.g. 0.23.0.dev...) also gets a warning instead of
   a hard ImportError.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Update rl_replacements.py

* Update rl_replacements.py

* Update rl.py

* Update rl_replacements.py

* Update rl_replacements.py

* Update rl.py

* Update rl.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update rl_replacements.py

* Update rl.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update rl_replacements.py, remove chat template from codexes commits

* Update rl.py, got rid of gradient checkpointing code that did not work

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Enable flex attention by default

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Avoid dropping flex attention when SDPA unsupported

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
GPT-OSS models use eager attention during inference because flex
attention returns incorrect results (likely due to left padding).
However, when _attn_implementation is set to "flex_attention",
transformers creates BlockMask objects which cause a TypeError
when passed to the eager attention path:

  TypeError: unsupported operand type(s) for +=: 'Tensor' and 'BlockMask'

This fix excludes GPT-OSS from using flex_attention, keeping it on
the eager path to avoid the BlockMask/Tensor type mismatch.
* Silence third-party deprecation warnings and fix socket resource leak

- Add warning filters for TorchAO deprecated import paths
- Filter SWIG builtin type warnings from bitsandbytes/triton
- Filter Triton autotuner deprecation warnings
- Filter Python 3.12+ multiprocessing fork warnings
- Filter resource warnings for unclosed sockets/files
- Fix socket leak in has_internet() by properly closing socket

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
_patch_trl_rl_trainers enumerates all trainer modules from dir(trl.trainer)
and attempts to import each one. Modules like alignprop_trainer fail because
they depend on optional packages (diffusers) that may not be installed. The
failure is harmless but the print() call produces noise on every import.

Change print() to logger.info() so these messages only appear when
UNSLOTH_ENABLE_LOGGING=1.

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
- Add cu126/cu128/cu130 xformers 0.0.34 wheel dependencies for torch 2.10
- Add cu126-torch2100, cu128-torch2100, cu130-torch2100 meta-dependencies
- Add cu126-ampere-torch2100, cu128-ampere-torch2100, cu130-ampere-torch2100 variants
- Update _auto_install.py version detection for torch 2.10.x
- Add CUDA check for torch 2.10 (requires CUDA 12.6, 12.8, or 13.0)
- Update README.md with torch 2.10 installation instructions

Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
* Improve MoE performance

* small changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix imports

* disable autotune

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* LoRA for MoE

* Make autotune default

* make dy contiguous

* use non lora model as base for RL

* Revert "use non lora model as base for RL"

This reverts commit 6e15d22.

* fixup derp

* non TMA [T4]

* Revert "non TMA [T4]"

This reverts commit 0d3cc76.

* Fixes for VL MoE and v5 transformers

* [transformers] [v5] remove unused hybridcache (#3910)

* remote unused hybridcache

* cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* No double compile for qwen3moe

* Fix top_k on trl GRPO

* Recognise GLM as MoE

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix missing RotaryEmbeddingConfigMixin

* Licensing for autotuning cache

* Cleanup

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Erland366 <erland.pg366@gmail.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
When datasets library has torchcodec installed but FFmpeg libraries
are missing, torchcodec raises a RuntimeError during import. The
exception handler only caught ImportError and AttributeError, causing
the error to propagate and crash Unsloth imports in environments
like Colab where FFmpeg may not be installed.

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
The cuda.cutlass_epilogue_fusion_enabled and cuda.cutlass_tma_only
inductor config options were added in PyTorch 2.8.0. Using these
options on older PyTorch versions causes a RuntimeError during
GRPOTrainer initialization.

This fix adds a version check to only include these options when
running PyTorch 2.8.0 or later, allowing GRPO training to work on
older PyTorch versions (e.g., Colab environments with PyTorch 2.5-2.7).

Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
* Disable torchcodec in transformers when FFmpeg is missing

When torchcodec is installed but FFmpeg libraries are unavailable,
transformers still thinks torchcodec is available (via find_spec check)
and tries to use it for audio loading, causing RuntimeError.

This adds disable_torchcodec_if_broken() which tests if torchcodec can
actually load its native libraries, and if not, patches transformers'
_torchcodec_available to False so it falls back to librosa instead.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Daniel Han <danielhanchen@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…3999)

On Windows and macOS (Python 3.8+), multiprocessing uses the spawn
start method. When datasets .map(num_proc=N) is called, it creates a
Pool(N) which re-imports __main__ in each worker, causing infinite
recursion and a RuntimeError during bootstrapping.

Guard the auto-computed dataset_num_proc in the generated Config
__init__ by checking multiprocessing.get_start_method() != 'fork'.
When the start method is not fork (spawn/forkserver), force
dataset_num_proc = None so datasets takes the single-process path.
Linux fork behavior is unchanged.

Also replace the fixed memory threshold logic with the simpler
adaptive approach: cap at 64, then min(num_proc, int(available_gb)),
with a safety floor of 1 when available memory is at or below 2GB.

Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly improves Unsloth's adaptability and stability across diverse computing environments. By making bitsandbytes an optional dependency, it broadens support for platforms like ROCm. Additionally, it enhances compatibility with various versions of external libraries such as PEFT and vLLM, and introduces a more robust BF16 support mechanism. These changes collectively aim to make Unsloth more resilient to different software configurations and hardware setups.

Highlights

  • Optional Bitsandbytes Integration: Made bitsandbytes imports conditional across several modules, enabling Unsloth to run robustly in environments where bitsandbytes is unavailable, such as ROCm.
  • Enhanced PEFT Compatibility: Introduced dynamic checks to ensure compatibility with older PEFT versions by conditionally removing the ensure_weight_tying argument from LoraConfig if not supported.
  • Improved vLLM Robustness: Added error handling for vLLM version detection to prevent crashes when its package metadata is missing, ensuring smoother integration.
  • Stable BF16 Support Helper: Provided a stable is_bf16_supported helper function for HIP/XPU backends, improving backward compatibility for existing notebooks.
  • GRPO Loss Path Correction: Fixed a potential issue in the GRPO loss calculation by aligning mask and coefficient lengths when left-padding leads to mismatches.
  • Dynamic Type-Tuple Construction: Ensured that type-tuples for Linear4bit classes are built only from successfully imported bitsandbytes and peft components, preventing errors when these libraries are not fully present.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • unsloth/init.py
    • Added a stable is_bf16_supported helper for HIP/XPU for backwards notebook compatibility.
  • unsloth/import_fixes.py
    • Wrapped importlib_version("vllm") in a try-except block to handle missing vLLM package metadata.
  • unsloth/kernels/utils.py
    • Guarded bitsandbytes import with a try-except block and conditionalized bnb related function assignments.
  • unsloth/models/_utils.py
    • Guarded bitsandbytes and peft.tuners.lora.bnb imports with try-except blocks.
  • unsloth/models/granite.py
    • Guarded Bnb_Linear4bit and Peft_Linear4bit imports, and dynamically created _BNB_LINEAR_TYPES tuple.
    • Modified isinstance check to use the new tuple.
  • unsloth/models/llama.py
    • Added logic to remove ensure_weight_tying argument from LoraConfig if not supported by the installed PEFT version.
  • unsloth/models/rl_replacements.py
    • Implemented logic in masked_batch_mean to align mask/coef lengths to prevent mismatches caused by left-padding.
  • unsloth/save.py
    • Guarded Bnb_Linear4bit and Peft_Linear4bit imports, dynamically created _MERGE_LORA_LINEAR_TYPES tuple, and modified _merge_lora's isinstance check.
Activity
  • No activity (comments, reviews, progress) has been recorded for this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This is a solid pull request that improves the robustness of Unsloth, particularly by making bitsandbytes an optional dependency. This is a great change for users on platforms like ROCm. The changes are implemented consistently across the codebase using try-except blocks to gracefully handle missing packages. The PR also includes several other valuable fixes, such as the backwards-compatible is_bf16_supported helper, a fix for GRPO loss calculation with left-padding, and improved compatibility with older PEFT versions. All the changes look correct and well-implemented. Great work!

@GoldenGrapeGentleman
Copy link
Copy Markdown
Contributor

W7900 (gfx1100, RDNA3) Test Report

Environment: ROCm 7.1 | PyTorch 2.8.0+rocm7.1 | Triton 3.4.0 | Transformers 4.57.6

@danielhanchen This PR is highly needed for RDNA GPU users — bitsandbytes ROCm support is still fragile on consumer cards (gfx1100/gfx1102/gfx1151). Making it optional unblocks bf16 workflows entirely.

I can confirm the following work without bitsandbytes on W7900:

Happy to run additional test cases if needed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.