Guard GPT-OSS allocator warmup on low-memory 4-bit loads#521
Guard GPT-OSS allocator warmup on low-memory 4-bit loads#521danielhanchen wants to merge 9 commits into
Conversation
Summary of ChangesHello @danielhanchen, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a critical enhancement to prevent Out-Of-Memory (OOM) errors during GPT-OSS 4-bit model loading on systems with limited GPU memory. By intelligently detecting available accelerator memory and conditionally skipping a large allocator warmup, it ensures models can load successfully in constrained environments. This change improves robustness without impacting performance on high-memory setups and provides user control over the new behavior. Highlights
🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console. Changelog
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request introduces a mechanism to guard the transformers.modeling_utils.caching_allocator_warmup for GPT-OSS 4-bit loads on low-memory GPUs. The changes are well-structured and address the OOM issue described. The addition of an environment variable for override is a good feature. My main feedback is to refactor one of the new functions to improve readability and maintainability by extracting magic values into constants. Overall, this is a solid improvement.
| model_name = os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_") | ||
| if "gpt_oss" not in model_name: | ||
| return False | ||
| if "_load_in_4bit_" not in model_name: | ||
| return False | ||
|
|
||
| mode = os.environ.get("UNSLOTH_GPT_OSS_ALLOCATOR_WARMUP", "auto").strip().lower() | ||
| if mode in ("off", "disable", "0", "false"): | ||
| return True | ||
| if mode in ("on", "enable", "1", "true"): | ||
| return False | ||
|
|
||
| total_memory = _get_effective_accelerator_memory_bytes() | ||
| if total_memory is None: | ||
| return False | ||
| return total_memory <= int(20 * 1024**3) |
There was a problem hiding this comment.
This function can be refactored for better readability and maintainability:
- The checks for
model_namecan be combined into a singleifstatement for conciseness. - The magic strings for
modechecking and the magic number for the memory threshold can be extracted into constants within the function's scope. Using sets for mode checking is also slightly more efficient for lookups. - Adding comments to explain the logic for "auto" mode and the memory threshold would improve clarity for future maintainers.
Here is a suggested refactoring:
model_name = os.environ.get("UNSLOTH_MODEL_NAME", "").replace("-", "_")
if "gpt_oss" not in model_name or "_load_in_4bit_" not in model_name:
return False
mode = os.environ.get("UNSLOTH_GPT_OSS_ALLOCATOR_WARMUP", "auto").strip().lower()
_WARMUP_OFF_MODES = {"off", "disable", "0", "false"}
if mode in _WARMUP_OFF_MODES:
return True
_WARMUP_ON_MODES = {"on", "enable", "1", "true"}
if mode in _WARMUP_ON_MODES:
return False
# Auto mode: skip on low memory devices
total_memory = _get_effective_accelerator_memory_bytes()
if total_memory is None:
return False
# 20 GiB threshold for low-memory devices
_LOW_MEMORY_THRESHOLD_BYTES = 20 * 1024**3
return total_memory <= _LOW_MEMORY_THRESHOLD_BYTESThere was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: d66df48dba
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| try: | ||
| if DEVICE_TYPE == "xpu": | ||
| return int(torch.xpu.memory.mem_get_info(0)[-1]) | ||
| return int(torch.cuda.memory.mem_get_info(0)[-1]) |
There was a problem hiding this comment.
Detect memory on the active accelerator, not device 0
The auto-skip decision always reads memory from accelerator index 0, so in multi-GPU or non-default-device runs it can evaluate the wrong device and keep warmup enabled even when the actual target device is low-memory/capped. In particular, if the model is loaded on cuda:1 (or another mapped device), this guard may miss the OOM condition it was added to prevent. The check should use the current/target device (or derive the relevant indices from expanded_device_map) instead of hard-coding 0.
Useful? React with 👍 / 👎.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request introduces a valuable safeguard against out-of-memory errors on devices with limited VRAM by conditionally skipping the allocator warmup. The logic for detecting low-memory environments is sound, and providing an environment variable for override is a good practice. Additionally, the refactoring to normalize the model name from the environment variable improves code consistency and fixes a potential bug with hyphenated model names. The changes are well-structured and enhance the robustness of the patching mechanism. I have a few minor suggestions to improve maintainability.
| if DEVICE_TYPE == "xpu": | ||
| return int(torch.xpu.memory.mem_get_info(0)[-1]) | ||
| return int(torch.cuda.memory.mem_get_info(0)[-1]) | ||
| except Exception: |
There was a problem hiding this comment.
Catching a broad Exception can hide unexpected errors and make debugging more difficult. It's better to catch more specific exceptions that you expect to handle, such as RuntimeError, ImportError, or AttributeError in this context. This makes the code's intent clearer and more robust against unrelated issues.
| except Exception: | |
| except (RuntimeError, ImportError, AttributeError): |
| fraction = float(torch.cuda.get_per_process_memory_fraction(0)) | ||
| if 0.0 < fraction < 1.0: | ||
| return int(total_memory * fraction) | ||
| except Exception: |
There was a problem hiding this comment.
Similar to the previous comment, catching a broad Exception is discouraged. For torch.cuda.get_per_process_memory_fraction, it's better to specifically handle expected exceptions like RuntimeError or NotImplementedError to avoid masking other potential bugs.
| except Exception: | |
| except (RuntimeError, NotImplementedError): |
|
|
||
| def _should_skip_transformers_allocator_warmup() -> bool: | ||
| """ | ||
| Skip transformers allocator warmup on low-memory accelerators. | ||
|
|
||
| `caching_allocator_warmup` can allocate large single chunks before weights | ||
| are loaded, which can OOM constrained GPUs. | ||
| """ | ||
| mode = os.environ.get("UNSLOTH_ALLOCATOR_WARMUP", "").strip().lower() | ||
| if mode in ("off", "disable", "0", "false"): | ||
| return True | ||
| if mode in ("on", "enable", "1", "true"): | ||
| return False | ||
|
|
||
| total_memory = _get_effective_accelerator_memory_bytes() | ||
| if total_memory is None: | ||
| return False | ||
| return total_memory <= int(24 * 1024**3) |
There was a problem hiding this comment.
To improve readability and maintainability, it's a good practice to define magic numbers like the memory threshold as a named constant. This makes the code easier to understand and modify in the future.
| def _should_skip_transformers_allocator_warmup() -> bool: | |
| """ | |
| Skip transformers allocator warmup on low-memory accelerators. | |
| `caching_allocator_warmup` can allocate large single chunks before weights | |
| are loaded, which can OOM constrained GPUs. | |
| """ | |
| mode = os.environ.get("UNSLOTH_ALLOCATOR_WARMUP", "").strip().lower() | |
| if mode in ("off", "disable", "0", "false"): | |
| return True | |
| if mode in ("on", "enable", "1", "true"): | |
| return False | |
| total_memory = _get_effective_accelerator_memory_bytes() | |
| if total_memory is None: | |
| return False | |
| return total_memory <= int(24 * 1024**3) | |
| _LOW_MEMORY_THRESHOLD_BYTES = 24 * 1024**3 | |
| def _should_skip_transformers_allocator_warmup() -> bool: | |
| """ | |
| Skip transformers allocator warmup on low-memory accelerators. | |
| `caching_allocator_warmup` can allocate large single chunks before weights | |
| are loaded, which can OOM constrained GPUs. | |
| """ | |
| mode = os.environ.get("UNSLOTH_ALLOCATOR_WARMUP", "").strip().lower() | |
| if mode in ("off", "disable", "0", "false"): | |
| return True | |
| if mode in ("on", "enable", "1", "true"): | |
| return False | |
| total_memory = _get_effective_accelerator_memory_bytes() | |
| if total_memory is None: | |
| return False | |
| return total_memory <= _LOW_MEMORY_THRESHOLD_BYTES |
|
Triage update:
Implemented in commit |
| return total_memory <= _LOW_MEMORY_ACCELERATOR_BYTES | ||
|
|
||
|
|
||
| def patch_transformers_caching_allocator_warmup(): |
| if hasattr(warmup_fn, "__unsloth_gpt_oss_guarded__"): | ||
| return | ||
|
|
||
| def guarded_caching_allocator_warmup(model, expanded_device_map, hf_quantizer): |
| return 0 | ||
| if hasattr(torch, "cuda") and hasattr(torch.cuda, "current_device"): | ||
| return int(torch.cuda.current_device()) | ||
| except Exception: |
| fraction = float(torch.cuda.get_per_process_memory_fraction(device_index)) | ||
| if 0.0 < fraction < 1.0: | ||
| return int(total_memory * fraction) | ||
| except Exception: |
|
Additional follow-up on #521 before final matrix runs:
Commit included on this branch: |
|
What are you trying which gave grads like |
| return 0 | ||
|
|
||
|
|
||
| def _get_accelerator_total_memory_bytes(): |
There was a problem hiding this comment.
NIT: We should put these in utils.py (away from temporary_patches) as these have potential to be reused.
Also there was a PR about DeviceContext iirc which might better handle this
Here it is unslothai/unsloth#3875
| if DEVICE_TYPE != "xpu" and hasattr(torch.cuda, "get_per_process_memory_fraction"): | ||
| try: | ||
| device_index = _get_active_accelerator_index() | ||
| fraction = float(torch.cuda.get_per_process_memory_fraction(device_index)) |
There was a problem hiding this comment.
Interesting. Are we restricting this somewhere else or is this to respect some hypothetical user setting such a limit?
| Set UNSLOTH_GPT_OSS_BNB4BIT_DISABLE=1 to force BF16 path. | ||
| """ | ||
| if "gpt_oss" not in os.environ.get("UNSLOTH_MODEL_NAME", ""): | ||
| if "gpt_oss" not in _normalized_unsloth_model_name(): |
There was a problem hiding this comment.
I see this being a repeat of #519 (some of it). We should try to consolidate to avoid merge conflicts
| logger.info(f'Unsloth: stripped use_experts_implementation decorator from {module}') | ||
| if ( | ||
| "use_experts_implementation" in stripped | ||
| or "use_kernel_forward_from_hub" in stripped |
There was a problem hiding this comment.
Actually nice. Never noticed this. Seems like a very recent change.
But this function is to strip the decorator. Why not have the decorator and use the kernel?
Summary
transformers.modeling_utils.caching_allocator_warmupin Unsloth patch flow.< 24 GiB) across model loads.UNSLOTH_ALLOCATOR_WARMUP=on|off|autogpt-oss->gpt_oss) consistently.UNSLOTH_GPT_OSS_ALLOCATOR_WARMUPalias support.0) for memory and per-process-fraction checks.@use_kernel_forward_from_hub(...)and@auto_docstringdecorators in class rewrite path to remove unknown-decorator warning noise.Why
caching_allocator_warmupcan allocate a large single chunk before weights load. On low-memory setups this can OOM before model load completes.UNSLOTH_MODEL_NAMEsubstring checks in multiple places. Hyphenated names can skip patch paths that should apply for 4-bit GPT-OSS loads.0.Validation
Run dir:
logs/memory_investigation_20260226_000809/warmup_24gb_global_20260226_014122On
Unsloth 2026.2.1/transformers 4.56.2:cap16_auto_unsloth2026_generic_names.log: PASSGPU_RESERVED_GB_AFTER_LOAD 11.85GPU_PEAK_RESERVED_GB_AFTER_LOAD 12.932cap16_force_on_unsloth2026_generic_names.log: FAIL (expected control)OutOfMemoryErrorwithTried to allocate 18.93 GiBtest_llama_compile_smoke.log: PASSHyphen-guard/load checks:
logs/hyphen_guard_verify_20260226_020925/repro_uninitialized_after_local_patch.log: PASS (LOAD_OK)Version matrix (
transformers==4.56.2):logs/version_matrix_gpt_oss_20260226_020432old_2025_12(unsloth==2025.12.10,unsloth-zoo==2025.12.8):load_ok=true,train_ok=true,exception_type=nulllosses=[1.165579, 4.171052, 3.063763]grad_norms=[2.650939, inf, 1.8630026331731638e+17]new_2026_2_1(unsloth==2026.2.1,unsloth-zoo==2026.2.1):load_ok=true,train_ok=true,exception_type=nulllosses=[1.156882, 4.137518, 3.050902]grad_norms=[2.590222, 5.900850786987662e+17, 9147114459174.867]Sanity check for active-device memory helpers:
_get_active_accelerator_indexpresent_get_accelerator_total_memory_bytes/_get_effective_accelerator_memory_bytesreturn valuesNotes
auto).2025.12or2026.2.1; both loaded successfully.