-
Notifications
You must be signed in to change notification settings - Fork 270
Guard GPT-OSS allocator warmup on low-memory 4-bit loads #521
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d66df48
054aa89
07df91b
5b1f0ff
b457eae
135d6ad
684f684
eaf1f4a
9cd8166
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -1094,13 +1094,105 @@ def patch_gpt_oss_bnb4bit_auto(): | |||||
| TEMPORARY_PATCHES.append(patch_gpt_oss_bnb4bit_auto) | ||||||
|
|
||||||
|
|
||||||
| _LOW_MEMORY_ACCELERATOR_BYTES = int(24 * 1024**3) | ||||||
|
|
||||||
|
|
||||||
| def _get_active_accelerator_index(): | ||||||
| try: | ||||||
| if DEVICE_TYPE == "xpu": | ||||||
| if hasattr(torch, "xpu") and hasattr(torch.xpu, "current_device"): | ||||||
| return int(torch.xpu.current_device()) | ||||||
| return 0 | ||||||
| if hasattr(torch, "cuda") and hasattr(torch.cuda, "current_device"): | ||||||
| return int(torch.cuda.current_device()) | ||||||
| except Exception: | ||||||
|
|
||||||
| pass | ||||||
| return 0 | ||||||
|
|
||||||
|
|
||||||
| def _get_accelerator_total_memory_bytes(): | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NIT: We should put these in utils.py (away from temporary_patches) as these have potential to be reused. |
||||||
| try: | ||||||
| device_index = _get_active_accelerator_index() | ||||||
| if DEVICE_TYPE == "xpu": | ||||||
| return int(torch.xpu.memory.mem_get_info(device_index)[-1]) | ||||||
| return int(torch.cuda.memory.mem_get_info(device_index)[-1]) | ||||||
| except Exception: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Catching a broad
Suggested change
|
||||||
| return None | ||||||
|
|
||||||
|
|
||||||
| def _get_effective_accelerator_memory_bytes(): | ||||||
| total_memory = _get_accelerator_total_memory_bytes() | ||||||
| if total_memory is None: | ||||||
| return None | ||||||
| if DEVICE_TYPE != "xpu" and hasattr(torch.cuda, "get_per_process_memory_fraction"): | ||||||
| try: | ||||||
| device_index = _get_active_accelerator_index() | ||||||
| fraction = float(torch.cuda.get_per_process_memory_fraction(device_index)) | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interesting. Are we restricting this somewhere else or is this to respect some hypothetical user setting such a limit? |
||||||
| if 0.0 < fraction < 1.0: | ||||||
| return int(total_memory * fraction) | ||||||
| except Exception: | ||||||
|
github-code-quality[bot] marked this conversation as resolved.
Fixed
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the previous comment, catching a broad
Suggested change
|
||||||
| pass | ||||||
| return total_memory | ||||||
|
|
||||||
|
|
||||||
| def _should_skip_transformers_allocator_warmup() -> bool: | ||||||
| """ | ||||||
| Skip transformers allocator warmup on low-memory accelerators. | ||||||
|
|
||||||
| `caching_allocator_warmup` can allocate large single chunks before weights | ||||||
| are loaded, which can OOM constrained GPUs. | ||||||
| """ | ||||||
| mode = os.environ.get("UNSLOTH_ALLOCATOR_WARMUP", "").strip().lower() | ||||||
| if mode in ("off", "disable", "0", "false"): | ||||||
| return True | ||||||
| if mode in ("on", "enable", "1", "true"): | ||||||
| return False | ||||||
|
|
||||||
| total_memory = _get_effective_accelerator_memory_bytes() | ||||||
| if total_memory is None: | ||||||
| return False | ||||||
| return total_memory <= _LOW_MEMORY_ACCELERATOR_BYTES | ||||||
|
|
||||||
|
|
||||||
| def patch_transformers_caching_allocator_warmup(): | ||||||
|
github-code-quality[bot] marked this conversation as resolved.
Fixed
github-code-quality[bot] marked this conversation as resolved.
Fixed
|
||||||
| try: | ||||||
| import transformers.modeling_utils | ||||||
| except Exception as e: | ||||||
| return raise_error("transformers.modeling_utils", e) | ||||||
|
|
||||||
| warmup_fn = transformers.modeling_utils.caching_allocator_warmup | ||||||
| if hasattr(warmup_fn, "__unsloth_allocator_warmup_guarded__"): | ||||||
| return | ||||||
| # Backward compatibility with previous guard attribute. | ||||||
| if hasattr(warmup_fn, "__unsloth_gpt_oss_guarded__"): | ||||||
| return | ||||||
|
|
||||||
| def guarded_caching_allocator_warmup(model, expanded_device_map, hf_quantizer): | ||||||
|
github-code-quality[bot] marked this conversation as resolved.
Fixed
github-code-quality[bot] marked this conversation as resolved.
Fixed
github-code-quality[bot] marked this conversation as resolved.
Fixed
github-code-quality[bot] marked this conversation as resolved.
Fixed
|
||||||
| if _should_skip_transformers_allocator_warmup(): | ||||||
| if UNSLOTH_ENABLE_LOGGING: | ||||||
| logger.warning_once( | ||||||
| "Unsloth: Skipping transformers caching_allocator_warmup " | ||||||
| "on low-memory accelerators (<24GB effective memory). " | ||||||
| "Set UNSLOTH_ALLOCATOR_WARMUP=on to keep warmup." | ||||||
| ) | ||||||
| return | ||||||
| return warmup_fn(model, expanded_device_map, hf_quantizer) | ||||||
|
|
||||||
| guarded_caching_allocator_warmup.__unsloth_allocator_warmup_guarded__ = True | ||||||
| # Keep legacy marker so older checks still detect this as guarded. | ||||||
| guarded_caching_allocator_warmup.__unsloth_gpt_oss_guarded__ = True | ||||||
| transformers.modeling_utils.caching_allocator_warmup = guarded_caching_allocator_warmup | ||||||
|
|
||||||
|
|
||||||
| TEMPORARY_PATCHES.append(patch_transformers_caching_allocator_warmup) | ||||||
|
|
||||||
|
|
||||||
| # Combo kernels uses too much VRAM for low memory GPUs | ||||||
| from ..device_type import DEVICE_TYPE | ||||||
|
|
||||||
| if DEVICE_TYPE == "xpu": | ||||||
| device_memory = torch.xpu.memory.mem_get_info(0)[-1] | ||||||
| else: | ||||||
| device_memory = torch.cuda.memory.mem_get_info(0)[-1] | ||||||
| device_memory = _get_accelerator_total_memory_bytes() | ||||||
| if device_memory is None: | ||||||
| device_memory = 0 | ||||||
| use_combo_kernels = False if device_memory/1024/1024/1024 <= 40 else True | ||||||
| fused_torch_compile_options = get_torch_compile_options( | ||||||
| epilogue_fusion = True, | ||||||
|
|
@@ -1316,15 +1408,19 @@ def _should_use_gpt_oss_bnb4bit() -> bool: | |||||
| Default: True when load_in_4bit is active. | ||||||
| Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to force BF16 path. | ||||||
| """ | ||||||
| if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): | ||||||
| if "gpt_oss" not in _normalized_unsloth_model_name(): | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see this being a repeat of #519 (some of it). We should try to consolidate to avoid merge conflicts |
||||||
| return False | ||||||
| if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): | ||||||
| if "_load_in_4bit_" not in _normalized_unsloth_model_name(): | ||||||
| return False | ||||||
| return os.environ.get("UNSLOTH_GPT_OSS_BNB4BIT_DISABLE", "0") != "1" | ||||||
|
|
||||||
|
|
||||||
| def _is_gpt_oss_4bit_load() -> bool: | ||||||
| return "_load_in_4bit_" in os.environ.get("UNSLOTH_MODEL_NAME", "") | ||||||
| return "_load_in_4bit_" in _normalized_unsloth_model_name() | ||||||
|
|
||||||
|
|
||||||
| def _normalized_unsloth_model_name() -> str: | ||||||
| return os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_") | ||||||
|
|
||||||
|
|
||||||
| def _is_transformers_v5() -> bool: | ||||||
|
|
@@ -1340,7 +1436,7 @@ def patch_gpt_oss_moe_for_lora(): | |||||
| IMPORTANT: We only patch the forward method, NOT replace the entire class. | ||||||
| This preserves the original class structure so weights load correctly. | ||||||
| """ | ||||||
| if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): | ||||||
| if "gpt_oss" not in _normalized_unsloth_model_name(): | ||||||
| return | ||||||
| if _is_gpt_oss_4bit_load() or _should_use_gpt_oss_bnb4bit(): | ||||||
| # 4-bit loads should keep quantized weights and use default PEFT LoRA. | ||||||
|
|
@@ -1774,8 +1870,8 @@ def patch_gpt_oss_linearized(): | |||||
| Patch GPT OSS for 4bit loading with grouped_mm support. | ||||||
| Only patches the GptOssExperts forward method - keeps original classes for proper weight loading. | ||||||
| """ | ||||||
| if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return | ||||||
| if "_load_in_4bit_" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return | ||||||
| if "gpt_oss" not in _normalized_unsloth_model_name(): return | ||||||
| if "_load_in_4bit_" not in _normalized_unsloth_model_name(): return | ||||||
| if _should_use_gpt_oss_bnb4bit(): return | ||||||
| try: | ||||||
| import transformers.models.gpt_oss.modeling_gpt_oss | ||||||
|
|
@@ -1813,7 +1909,7 @@ def experts_forward( | |||||
|
|
||||||
| def patch_GptOssAttention(): | ||||||
| if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return | ||||||
| if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return | ||||||
| if "gpt_oss" not in _normalized_unsloth_model_name(): return | ||||||
| try: | ||||||
| from ..flex_attention import ( | ||||||
| flex_attention_with_sink, | ||||||
|
|
@@ -2054,7 +2150,7 @@ def forward( | |||||
|
|
||||||
| def patch_GptOssModel(): | ||||||
| if os.environ.get("UNSLOTH_ENABLE_FLEX_ATTENTION", "1") == "0": return | ||||||
| if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): return | ||||||
| if "gpt_oss" not in _normalized_unsloth_model_name(): return | ||||||
| try: | ||||||
| import transformers.models.gpt_oss.modeling_gpt_oss | ||||||
| transformers.models.gpt_oss.modeling_gpt_oss.GptOssModel | ||||||
|
|
@@ -2075,12 +2171,25 @@ def patch_GptOssModel(): | |||||
| import transformers.generation.utils | ||||||
| def wrap(f): | ||||||
| def return_attention_mask(*args, **kwargs): | ||||||
| if kwargs["input_embeds"].requires_grad: | ||||||
| input_embeds = kwargs.get("input_embeds", None) | ||||||
| if input_embeds is None: | ||||||
| input_embeds = kwargs.get("inputs_embeds", None) | ||||||
| if input_embeds is None: | ||||||
| for arg in args: | ||||||
| if type(arg) is torch.Tensor and arg.is_floating_point(): | ||||||
| input_embeds = arg | ||||||
| break | ||||||
|
|
||||||
| if input_embeds is not None and input_embeds.requires_grad: | ||||||
| if "attention_mask" in kwargs: | ||||||
| return kwargs["attention_mask"] | ||||||
| for arg in args: | ||||||
| if type(arg) is torch.Tensor and arg.dtype == torch.int32: | ||||||
| if ( | ||||||
| type(arg) is torch.Tensor and | ||||||
| arg.dtype in (torch.int32, torch.int64, torch.bool) | ||||||
| ): | ||||||
| return arg | ||||||
| return f(*args, **kwargs) | ||||||
| else: | ||||||
| # Eager | ||||||
| return f(*args, **kwargs) | ||||||
|
|
@@ -2739,7 +2848,7 @@ def patch_gpt_oss_config(): | |||||
|
|
||||||
|
|
||||||
| def patch_gpt_oss_init_weights_modulelist_fix(): | ||||||
| if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): | ||||||
| if "gpt_oss" not in _normalized_unsloth_model_name(): | ||||||
| return | ||||||
| try: | ||||||
| import transformers.models.gpt_oss.modeling_gpt_oss | ||||||
|
|
@@ -2784,7 +2893,7 @@ def patch_gpt_oss_for_grpo(): | |||||
| When UNSLOTH_RETURN_HIDDEN_STATES=1, return hidden_states instead of logits. | ||||||
| This fixes the matrix multiplication dimension mismatch issue in GRPO training. | ||||||
| """ | ||||||
| if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): | ||||||
| if "gpt_oss" not in _normalized_unsloth_model_name(): | ||||||
| return | ||||||
|
|
||||||
| try: | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually nice. Never noticed this. Seems like a very recent change.
But this function is to strip the decorator. Why not have the decorator and use the kernel?