introduce device_context to simplify code. #3875

Closed
ykaitao wants to merge 3437 commits into unslothai:main from ykaitao:ktyang_device_context

Conversation

@ykaitao ykaitao commented Jan 11, 2026

Contributor

No description provided.

numb3r33 and others added 30 commits December 24, 2025 01:34
Use regex to dynamically detect and preserve the original indentation
when replacing the 'return output' statement, instead of hardcoding
spaces. This ensures the patched code maintains consistent indentation
regardless of the original formatting.
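The indentation-preserving replacement described above can be sketched as follows. This is a minimal illustration, not the actual patch: the function name `replace_return_output` and the sample source string are hypothetical, and only the standard `re` module is assumed.

```python
import re

def replace_return_output(source: str, replacement: str) -> str:
    """Swap each `return output` line for `replacement`, reusing that line's indentation."""
    # Group 1 captures whatever whitespace precedes the statement.
    pattern = re.compile(r"^([ \t]*)return output[ \t]*$", flags=re.MULTILINE)

    def _substitute(match):
        indent = match.group(1)
        # Prefix every line of the replacement with the captured indentation.
        return "\n".join(indent + line for line in replacement.splitlines())

    return pattern.sub(_substitute, source)

# Hypothetical input with an unusual 8-space indent.
original = "def f(x):\n        output = x + 1\n        return output\n"
patched = replace_return_output(original, "log(output)\nreturn output")
```

Using a function as the replacement avoids having to escape backslashes in the replacement text, which a plain replacement string would require.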
Replace f-string triple-quoted approach with explicit newline characters
for clearer string construction in the grpo_trainer patch.
* Add missing import of inspect

* Update device_type.py
…nslothai#3768)

* Improve error message for fast_inference and full_finetuning

* Refine error message string formatting

* Update unsloth/models/vision.py

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
…nslothai#3780)

* fix(trainer): import psutil to prevent NameError in _prepare_dataset

Fixes unslothai#3777

* Update rl.py

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: Francesco Bertolotti <francesco.bertolotti@igenius.ai>
* Guard optional trl.experimental.openenv usage in RL patches

* Simplify optional trl.openenv import handling

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…3790)

* Fix is_contiguous() method call and remove duplicate imports

- Fix bug in rope_embedding.py where is_contiguous was used without
  parentheses, causing the method object (always truthy) to be evaluated
  instead of calling the method. This fixes issue unslothai#3781 where fast rope
  backpropagation was broken for zero strided/non-contiguous tensors.

- Remove duplicate `import torch` in rl.py (lines 20 and 25)
- Remove duplicate `import functools` and `import types` in vision.py
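The `is_contiguous` bug is easy to reproduce without PyTorch. The sketch below uses a hypothetical stand-in class `FakeTensor` (not a real tensor type) to show why a method referenced without parentheses always evaluates as truthy.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor; models only the is_contiguous() method."""

    def __init__(self, contiguous: bool) -> None:
        self._contiguous = contiguous

    def is_contiguous(self) -> bool:
        return self._contiguous

t = FakeTensor(contiguous=False)

# Buggy form: this tests the bound method object itself, which is always truthy.
buggy = bool(t.is_contiguous)

# Fixed form: this calls the method and tests its result.
fixed = bool(t.is_contiguous())
```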

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Fix Boolean value of Tensor ambiguity error in mistral.py

Replace `or` operator with explicit `is None` check when getting
n_items from kwargs. The `or` operator fails when the value is a
Tensor because Python cannot determine the boolean value of a
multi-element tensor.
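A minimal sketch of the failure and the fix, without PyTorch. `MultiElement` is a hypothetical stand-in whose `__bool__` raises the way a multi-element tensor does, and the kwargs key names are assumptions for illustration only.

```python
class MultiElement:
    """Hypothetical stand-in for a multi-element tensor: bool() on it raises."""

    def __bool__(self) -> bool:
        raise RuntimeError("Boolean value of Tensor with more than one value is ambiguous")

def get_n_items_buggy(kwargs: dict):
    # `or` calls bool() on the left operand, which raises for tensor-like values.
    return kwargs.get("num_items_in_batch") or kwargs.get("n_items")

def get_n_items_fixed(kwargs: dict):
    # An identity check against None never asks the value for its truthiness.
    n_items = kwargs.get("num_items_in_batch")
    if n_items is None:
        n_items = kwargs.get("n_items")
    return n_items
```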

Fixes unslothai#3766

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* Update rope_embedding.py

---------

Co-authored-by: yurekami <yurekami@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
…lothai#3794)

Add "corda" as an allowed value for the init_lora_weights parameter
in FastLanguageModel.get_peft_model() and FastBaseModel.get_peft_model().

This enables users to use CorDA (Correlation-aware Decomposed Adaptation)
initialization from PEFT, which provides an alternative LoRA initialization
strategy for improved finetuning performance.

Fixes unslothai#3693

Signed-off-by: majiayu000 <1835304752@qq.com>
…lothai#3811)

* Fix correctness bugs in rl.py, rl_replacements.py, and vision.py

1. rl_replacements.py (lines 864, 870): Fixed undefined `nanmin`/`nanmax`
   functions by using `.nan_to_num(nan=inf/-inf).min()/.max()` pattern.
   PyTorch doesn't have torch.nanmin/nanmax, so we replace NaN values
   before computing min/max.

2. vision.py (line 150): Fixed bug where code checked for "input" key
   but then accessed kwargs["input_ids"] instead of kwargs["input"].

3. vision.py (line 159): Fixed bug where literal string "key" was used
   instead of the variable `key` when accessing kwargs.

4. rl.py (lines 903, 905): Fixed non-existent `MathError` exception
   by replacing with `ValueError`.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1. cohere.py:347-348 - Fixed wrong variable names in QK normalization.
   Used `Q`/`K` but variables were named `Qn`/`Kn`. This caused NameError
   when `use_qk_norm=True` (e.g., c4ai-command-r-plus models).

2. cohere.py:482 - Fixed wrong object reference in inference loop.
   Used `self.mlp` but should be `decoder_layer.mlp` since we're
   iterating through decoder layers. Caused AttributeError during inference.

3. falcon_h1.py:459,461 - Fixed wrong attribute names in inference path.
   Used `post_attention_layernorm` and `mlp` but Falcon H1 uses
   `pre_ff_layernorm` and `feed_forward`. Caused AttributeError during generation.

4. qwen3_moe.py:210 - Fixed wrong module path with incorrect capitalization.
   Used `transformers.models.Qwen3Moe` but should be `transformers.models.qwen3_moe`.
   Caused AttributeError when patching rotary embeddings.

5. qwen3_moe.py:239 - Fixed wrong model_patcher class.
   Used `FastQwen3Model` but should be `FastQwen3MoeModel` for MoE models.
   Caused incorrect patching for Qwen3 MoE models.

6. hf_hub.py:21-22 - Fixed floor division and missing return for billion values.
   Used `//` instead of `/` for millions, and had no return for values >= 1B.
   Caused incorrect formatting and None return for large numbers.

7. save.py:550 - Fixed self-assignment that did nothing.
   `sharded_ram_usage = sharded_ram_usage` should be `= max_shard_size`.
   Caused integer shard sizes to be ignored.

8. rl.py:562-567 - Fixed orphan string not included in length_check.
   The elif branch for max_seq_length validation was a standalone string
   expression, not concatenated to length_check. Caused silent skip of
   the max_seq_length > model_max_seq_length warning.

9. granite.py:49-52 - Fixed wrong model name and version in error message.
   Said "Gemma2" and "4.42.3" but should be "Granite" and "4.45.0".
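Item 6 above (floor division and the missing return for billions) can be illustrated with a small sketch. The function name `format_count`, the thresholds, and the output format are all assumptions; the original `hf_hub.py` code is not shown in this thread.

```python
def format_count(n: int) -> str:
    """Format a count with K/M/B suffixes (hypothetical helper, not the hf_hub.py code)."""
    if n < 1_000:
        return str(n)
    if n < 1_000_000:
        return f"{n / 1_000:.1f}K"
    if n < 1_000_000_000:
        # True division keeps the fractional part; `//` would truncate 1.5M to 1M.
        return f"{n / 1_000_000:.1f}M"
    # Explicit return for billions, so large values no longer fall through to None.
    return f"{n / 1_000_000_000:.1f}B"
```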
…tmul

Fix 3D tensor support for bitsandbytes 8-bit matmul in forward pass
🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
FIX: weight tying for LoRA embeddings and lm_head
Gemma3 models have a large vocabulary (262144 tokens) which causes
training loss to explode when using int8 embedding quantization.

This fix auto-detects Gemma3 models and switches from int8-int4
(phone-deployment) to int4 weight-only QAT for stable training.
…lity

Fix Gemma3 QAT training instability with int8-int4 scheme
When users load a model with fast_inference=False but then try to use
vLLM-style arguments with fast_generate, they previously got confusing
errors. This adds a wrapper that detects common mistakes and provides
helpful guidance:

- Using sampling_params: explains to use HF generate args instead
- Using lora_request: explains LoRA weights are already merged
- Passing text strings: shows how to tokenize input first

Changes:
- Add make_fast_generate_wrapper to _utils.py
- Apply wrapper in llama.py when fast_inference=False
- Apply wrapper in vision.py when fast_inference=False
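The wrapper idea can be sketched as below. This is not the actual `make_fast_generate_wrapper`; the name `make_guarded_generate`, the choice to raise `ValueError`, and the message wording are assumptions made for illustration.

```python
import functools

def make_guarded_generate(generate_fn):
    """Wrap an HF-style generate function and reject vLLM-style arguments with guidance."""

    @functools.wraps(generate_fn)
    def wrapper(*args, **kwargs):
        if "sampling_params" in kwargs:
            raise ValueError(
                "sampling_params is a vLLM argument; pass Hugging Face generate "
                "arguments such as max_new_tokens instead."
            )
        if "lora_request" in kwargs:
            raise ValueError(
                "lora_request is a vLLM argument; the LoRA weights are already "
                "merged, so it is not needed here."
            )
        # Detect raw text (a string or a non-empty list of strings) passed as input.
        first = args[0] if args else kwargs.get("input_ids")
        is_text_list = (
            isinstance(first, (list, tuple))
            and len(first) > 0
            and all(isinstance(x, str) for x in first)
        )
        if isinstance(first, str) or is_text_list:
            raise ValueError("Text input is not supported here; tokenize it first.")
        return generate_fn(*args, **kwargs)

    return wrapper

# Hypothetical generate function used only to exercise the wrapper.
def fake_generate(input_ids=None, **kwargs):
    return ("generated", input_ids)

guarded = make_guarded_generate(fake_generate)
```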
GoldenGrapeGentleman and others added 20 commits March 1, 2026 00:15
…ads (unslothai#4026)

Fix global dequantize buffer dtype mismatch when loading multiple 4-bit models with different dtypes in the same process. Adds dtype check alongside existing None check for WEIGHT_BUFFER in both CUDA/HIP and XPU paths.
…#4034)

Co-authored-by: Daniel Hanchen <danielhanchen@users.noreply.github.com>
)

* Fix auto padding free logic to respect user passed

* Update unsloth/trainer.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Add Qwen3.5 to FORCE_FLOAT32

* fix vision encoder dtype mismatch

* revert vision cast changes
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.15.2 → v0.15.4](astral-sh/ruff-pre-commit@v0.15.2...v0.15.4)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Updated with Qwen3.5 Small models
…nslothai#4136)

Current arch.startswith("gfx1") incorrectly matches:
  - RDNA1 (gfx10xx) and RDNA2 (gfx103x): not ROCm supported
  - gfx1102 (RX 7600), gfx1103 (Phoenix APU): not in ROCm support matrix
  - gfx1150/1151/1152 (RDNA3.5 APUs): not in ROCm support matrix

Replace with explicit whitelist aligned to the ROCm Linux support matrix:
  https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html

  gfx1100 - RDNA3 discrete (RX 7900 series, PRO W7900/W7800)
  gfx1101 - RDNA3 discrete (RX 7800/7700 series, PRO W7700)
  gfx1200 - RDNA4 discrete (RX 9060 series)
  gfx1201 - RDNA4 discrete (RX 9070 series, AI PRO R9700)

Mirrors the existing is_cdna() pattern. Avoids silently applying
unverified Triton kernel tuning to unsupported hardware.
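A minimal sketch of the whitelist check, using the four architectures listed above. The function name `is_supported_rdna` is hypothetical, and stripping a `:feature` suffix from the architecture string is an assumption about how the name may be reported.

```python
# Architectures taken from the list in this commit message.
SUPPORTED_RDNA_ARCHS = frozenset({"gfx1100", "gfx1101", "gfx1200", "gfx1201"})

def is_supported_rdna(arch: str) -> bool:
    """Return True only for explicitly whitelisted RDNA architectures."""
    # Assumption: the name may carry feature flags such as "gfx1100:xnack-",
    # so only the base name before the first colon is compared.
    return arch.split(":")[0] in SUPPORTED_RDNA_ARCHS
```

A prefix test such as `arch.startswith("gfx1")` accepts `gfx1030` and `gfx1102`, while the set lookup rejects both.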
* Fix lm_head lora save

* Fix _need_to_train_embeddings guard for lm_head LoRA targets

When lm_head is already in final_modules as a LoRA target, the
_need_to_train_embeddings block should not also add it to
modules_to_save. This prevents dual-wrapping (LoRA + modules_to_save
on the same module) which causes assertion failures downstream.

Check if embed_tokens/lm_head are already being trained as LoRA
targets before adding them to modules_to_save. Also prevents
duplicate entries with elif guards.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* add intel support for torch210

* fix for typo
…support (unslothai#4138)

* fix: update GGUF save paths to use ~/.unsloth/llama.cpp with Windows support

* fix: quote LLAMA_CPP_DEFAULT_DIR in fallback shell commands to handle paths with spaces

* refactor: deduplicate platform-specific build instructions in quantization error message

* chore: remove accidentally committed PR description file

* Fix import safety and f-string bugs in save.py

- H4: Add defensive try/except for LLAMA_CPP_DEFAULT_DIR and IS_WINDOWS imports
  with fallback defaults, so save.py works even if zoo PR unslothai#526 is not merged yet
- H5: Fix Kaggle error path using plain "Error: {e}" instead of f"Error: {e}",
  so the actual exception is shown to users

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Datta Nimmaturi <venkatadattasainimmaturi@gmail.com>
Co-authored-by: Daniel Han <danielhanchen@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fixup mapper issues and resolve properly

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* Fix broken wandb import crashing unsloth startup

When wandb is installed but broken (e.g., wandb < 0.19.11 with
protobuf >= 6.0), the import chain unsloth -> trl -> transformers ->
is_wandb_available() -> import wandb crashes with:

  ImportError: cannot import name 'Imports' from
  'wandb.proto.wandb_telemetry_pb2'

This happens because transformers' is_wandb_available() has no
try/except around `import wandb`. The error propagates up and kills
`from unsloth import FastLanguageModel` even though wandb is optional.

Add disable_broken_wandb() following the same pattern as
disable_torchcodec_if_broken(). It proactively tries importing wandb
during early init, and if the import fails, patches
is_wandb_available() to return False and sets WANDB_DISABLED=true.
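The probing step can be sketched as below. This is a simplified illustration, not the actual `disable_broken_wandb()`: the helper name `probe_optional_module` is hypothetical, and the sketch only covers the import attempt and the environment variable, not the patching of `is_wandb_available()`.

```python
import importlib
import os

def probe_optional_module(module_name: str, disable_env_var: str) -> bool:
    """Try importing an optional module; on any failure, set a disable flag and report False."""
    try:
        importlib.import_module(module_name)
        return True
    except Exception:
        # A broken install can raise errors other than ModuleNotFoundError,
        # so any ordinary exception is treated as "not usable".
        os.environ[disable_env_var] = "true"
        return False
```

In the scenario described above this would be called as `probe_optional_module("wandb", "WANDB_DISABLED")` during early initialization.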

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
…slothai#4148)

trl/trainer/callbacks.py imports is_wandb_available from
accelerate.utils, not from transformers. The original fix in unslothai#4147
only patched the transformers version, so `from trl import GRPOTrainer`
still crashed via the callbacks.py -> accelerate -> wandb path.

Must patch both the source module (accelerate.utils.imports) AND the
re-export namespace (accelerate.utils) since Python's
`from accelerate.utils import X` reads from the latter, which holds
its own cached reference.
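The cached-reference behaviour can be reproduced with two stand-in module objects. The names `pkg` and `pkg.imports` are hypothetical and merely mimic a package that re-exports a helper from a submodule.

```python
import types

# Stand-in for the source module that defines the helper.
source_mod = types.ModuleType("pkg.imports")
source_mod.is_available = lambda: True

# Stand-in for the package namespace; this assignment is what
# `from .imports import is_available` does inside the package's __init__.
reexport_ns = types.ModuleType("pkg")
reexport_ns.is_available = source_mod.is_available

# Patching only the source module leaves the re-export bound to the old function.
source_mod.is_available = lambda: False
stale = reexport_ns.is_available()

# Patching the re-export namespace as well closes the gap.
reexport_ns.is_available = source_mod.is_available
patched = reexport_ns.is_available()
```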
@ykaitao ykaitao force-pushed the ktyang_device_context branch from 0e62bc7 to bf557e6 Compare March 6, 2026 07:08
@ykaitao ykaitao requested review from Datta0 and mmathew23 as code owners March 6, 2026 07:08
ykaitao commented Mar 6, 2026

Contributor Author

Hi Team, I have resolved the conflicts. @mmathew23 @Datta0 @danielhanchen

ykaitao commented Mar 6, 2026

Contributor Author

/gemini review

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: bf557e64b9

Comment thread unsloth/device_type.py
    """Encapsulates device-specific operations for XPU/HIP/CUDA."""

    def __init__(self, device_type: str = DEVICE_TYPE) -> None:
        DEVICE_MODULE_MAP = {"xpu": torch.xpu, "cuda": torch.cuda, "hip": torch.cuda}

P1: Avoid eager torch.xpu lookup in device module map

Constructing DEVICE_MODULE_MAP with "xpu": torch.xpu eagerly dereferences torch.xpu even when running on CUDA/HIP, so CUDA-only PyTorch builds that do not expose an xpu attribute will fail during import with AttributeError before device_type is checked. This is a startup regression (the file already treats xpu as optional via hasattr(torch, "xpu") in get_device_type()), and it can block all model initialization on those environments.
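One way to avoid the eager lookup is to map device types to attribute names and resolve the module only for the requested device. The sketch below is an illustration under assumptions, not the PR's code: `resolve_device_module` is a hypothetical helper, and a `SimpleNamespace` stands in for a CUDA-only `torch` so the example runs without PyTorch.

```python
import types

# Map each device type to the attribute name on the torch module, not the module itself.
_DEVICE_ATTR = {"xpu": "xpu", "cuda": "cuda", "hip": "cuda"}

def resolve_device_module(torch_like, device_type: str):
    """Look up the backend module for `device_type` only when it is requested."""
    if device_type not in _DEVICE_ATTR:
        raise ValueError(f"Unsloth: Unsupported device type: {device_type}")
    attr = _DEVICE_ATTR[device_type]
    module = getattr(torch_like, attr, None)
    if module is None:
        raise RuntimeError(f"This PyTorch build does not expose torch.{attr}")
    return module

# Stand-in for a CUDA-only build that has no `xpu` attribute.
cuda_only_torch = types.SimpleNamespace(cuda=object())
hip_module = resolve_device_module(cuda_only_torch, "hip")
```

With this shape, a missing `xpu` attribute only matters when `device_type` is actually `"xpu"`.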

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request introduces a DeviceContext class to encapsulate and simplify device-specific operations, which is a great improvement for code clarity and maintainability. The changes in llama.py and vision.py effectively leverage this new class, removing duplicated code. My review includes a few suggestions to further improve the DeviceContext class by moving method-level dictionaries to class-level constants for better organization, and to make exception handling more specific.

Comment thread unsloth/device_type.py
Comment on lines +134 to +161
class DeviceContext:
    """Encapsulates device-specific operations for XPU/HIP/CUDA."""

    def __init__(self, device_type: str = DEVICE_TYPE) -> None:
        DEVICE_MODULE_MAP = {"xpu": torch.xpu, "cuda": torch.cuda, "hip": torch.cuda}
        if device_type not in DEVICE_MODULE_MAP:
            raise ValueError(f"Unsloth: Unsupported device type: {device_type}")
        self.device_type = device_type
        # Cache the torch module for this device
        self.torch_module = DEVICE_MODULE_MAP[device_type]

    def get_stats(self) -> tuple[str, str, float]:
        """Return (name, stats_snippet, max_memory_gb)."""
        gpu_stats = self.torch_module.get_device_properties(0)
        max_mem = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

        # Device name
        name = gpu_stats.name + ". " if gpu_stats.name else self._get_default_name()

        # Toolkit snippet
        snippet = self._get_toolkit_snippet(gpu_stats)

        return name, snippet, max_mem

    def _get_default_name(self) -> str:
        """Get default device name when props.name is empty."""
        names = {"xpu": "Intel XPU", "cuda": "NVIDIA GPU", "hip": "AMD GPU"}
        return names[self.device_type] + " Device. "

Contributor

medium

For better code organization and to avoid re-creating dictionaries on each method call, it's a good practice to define DEVICE_MODULE_MAP and names as class-level constants.

class DeviceContext:
    """Encapsulates device-specific operations for XPU/HIP/CUDA."""
    DEVICE_MODULE_MAP = {"xpu": torch.xpu, "cuda": torch.cuda, "hip": torch.cuda}
    _DEFAULT_NAMES = {"xpu": "Intel XPU", "cuda": "NVIDIA GPU", "hip": "AMD GPU"}

    def __init__(self, device_type: str = DEVICE_TYPE) -> None:
        if device_type not in self.DEVICE_MODULE_MAP:
            raise ValueError(f"Unsloth: Unsupported device type: {device_type}")
        self.device_type = device_type
        # Cache the torch module for this device
        self.torch_module = self.DEVICE_MODULE_MAP[device_type]

    def get_stats(self) -> tuple[str, str, float]:
        """Return (name, stats_snippet, max_memory_gb)."""
        gpu_stats = self.torch_module.get_device_properties(0)
        max_mem = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)

        # Device name
        name = gpu_stats.name + ". " if gpu_stats.name else self._get_default_name()

        # Toolkit snippet
        snippet = self._get_toolkit_snippet(gpu_stats)

        return name, snippet, max_mem

    def _get_default_name(self) -> str:
        """Get default device name when props.name is empty."""
        return self._DEFAULT_NAMES[self.device_type] + " Device. "

Comment thread unsloth/models/llama.py
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
try:
    vllm_version = f" vLLM: {importlib_version('vllm')}."
except:

Contributor

medium

It's generally better to catch specific exceptions rather than using a bare except:. This avoids accidentally catching other unexpected errors like KeyboardInterrupt or SystemExit. Using except Exception: is a safer alternative.

Suggested change
- except:
+ except Exception:
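The difference between the two forms can be demonstrated with a short sketch. The helper names are hypothetical; the behaviour shown follows from Python's exception hierarchy, where `KeyboardInterrupt` and `SystemExit` derive from `BaseException` rather than `Exception`.

```python
def run_with_bare_except(fn) -> str:
    try:
        fn()
    except:  # noqa: E722 (deliberately bare, to show what it catches)
        return "swallowed"
    return "ok"

def run_with_except_exception(fn) -> str:
    try:
        fn()
    except Exception:
        return "swallowed"
    return "ok"

def interrupt() -> None:
    # Simulates the user pressing Ctrl+C inside the guarded call.
    raise KeyboardInterrupt

def fail() -> None:
    raise ValueError("ordinary error")
```

`run_with_bare_except(interrupt)` hides the interrupt, whereas `run_with_except_exception(interrupt)` lets it propagate. If `importlib_version` is `importlib.metadata.version` (an assumption about the import in these files), catching `importlib.metadata.PackageNotFoundError` would be narrower still.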

Comment thread unsloth/models/vision.py
gpu_stats_name, gpu_stats_snippet, max_memory = device_context.get_stats()
try:
    vllm_version = f" vLLM: {importlib_version('vllm')}."
except:

Contributor

medium

It's generally better to catch specific exceptions rather than using a bare except:. This avoids accidentally catching other unexpected errors like KeyboardInterrupt or SystemExit. Using except Exception: is a safer alternative.

Suggested change
- except:
+ except Exception:
