Skip to content

upgrade transformers==4.55.1 and bitsandbytes==0.47.0#3064

Merged
winglian merged 5 commits into
mainfrom
transformers-upgrade
Aug 13, 2025
Merged

upgrade transformers==4.55.1 and bitsandbytes==0.47.0#3064
winglian merged 5 commits into
mainfrom
transformers-upgrade

Conversation

@winglian

@winglian winglian commented Aug 13, 2025

Copy link
Copy Markdown
Collaborator

Summary by CodeRabbit

  • New Features

    • None
  • Bug Fixes

    • Improved runtime compatibility across Transformers versions to avoid import-time failures and stabilize sliding-window flash-attention behavior.
  • Chores

    • Updated dependencies: bitsandbytes → 0.47.0 and transformers → 4.55.1 (macOS packaging preserved).
    • Switched a Docker package install to a released PyPI version and added build-skip env var for non-CUDA builds.
  • Refactor

    • Reworked FSDP initialization to natively support 4-bit (BitsAndBytes) parameters and removed prior global torch-function interception.
  • Tests

    • Test suite simplified to focus on FSDP init patch verification; removed legacy torch_function-related tests.

@winglian winglian marked this pull request as ready for review August 13, 2025 16:28
@coderabbitai

coderabbitai Bot commented Aug 13, 2025

Copy link
Copy Markdown
Contributor
📝 Walkthrough

Walkthrough

Bumps two dependency pins in requirements.txt; extends and hardens flash-attention import fallbacks and default behavior; removes BitsAndBytes torch_function patching and replaces it with FSDPParam init-time patches; removes invocation of the BnB torch-function patch; narrows related end-to-end tests accordingly. No public API signatures changed.

Changes

Cohort / File(s) Summary of changes
Dependencies
requirements.txt
Bumped bitsandbytes 0.46.1 → 0.47.0 and transformers 4.55.0 → 4.55.1; no other dependency additions/removals.
Ring-attention import fallbacks
src/axolotl/monkeypatch/ring_attn/adapters/batch.py, src/axolotl/monkeypatch/ring_attn/patch.py
Make import of _flash_supports_window robust: try original symbol, then _flash_supports_window_size as alias, else default to True; add pylint duplicate-code disable on outer try.
Validation logic
src/axolotl/utils/schemas/validation.py
In check_context_parallel_size, when ring_flash_attn is available, set _flash_supports_window = True unconditionally and inject is_flash_attn_greater_or_equal into transformers.modeling_flash_attention_utils via sys.modules/setattr.
Patch manager changes
src/axolotl/loaders/patch_manager.py
Remove import and invocation of apply_bnb_torch_function_patch from the FSDP2+BnB patch path; only apply_init_sharded_param_patch and apply_init_unsharded_param_patch are applied.
FSDP2 QLoRA monkeypatch
src/axolotl/monkeypatch/fsdp2_qlora.py
Remove global torch_function-based BitsAndBytes patching and related helpers; introduce dynamic in-place patches of FSDPParam._init_sharded_param and FSDPParam.init_unsharded_param to handle Params4bit by constructing bnb.nn.modules.Params4bit or falling back to nn.Parameter. Deleted functions: patched_torch_function, apply_bnb_torch_function_patch.
Tests — e2e patched
tests/e2e/patched/test_fsdp2_qlora.py
Removed BnB/torch_function test scaffolding and fixture; added module docstring and renamed test to focus solely on verifying FSDP init patches (test_fsdp2_init_patches), removing assertions about torch_function behavior and Params4bit constructor attributes.
Docker base
docker/Dockerfile-base
Replace git-based install of causal-conv1d with PyPI causal_conv1d==1.5.2 and set CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE before install.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~35 minutes

Possibly related PRs

Suggested labels

ready to merge

Suggested reviewers

  • djsaunde
  • NanoCode012
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch transformers-upgrade

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR/Issue comments)

Type @coderabbitai help to get the list of available commands.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

Status, Documentation and Community

  • Visit our Status Page to check the current availability of CodeRabbit.
  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@github-actions

github-actions Bot commented Aug 13, 2025

Copy link
Copy Markdown
Contributor

📖 Documentation Preview: https://689cf6e1e1a684da77c5820c--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit e6d722f

@winglian

Copy link
Copy Markdown
Collaborator Author

converting this PR from a draft to ready seems to still skip the multigpu CI. manual run here: https://github.com/axolotl-ai-cloud/axolotl/actions/runs/16943460540

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🔭 Outside diff range comments (2)
src/axolotl/monkeypatch/ring_attn/adapters/batch.py (1)

121-129: Use the runtime helper to decide sliding window support.

This ensures correct behavior across transformers versions (attr vs func) and picks up any runtime overrides.

-        use_sliding_windows = (
-            _flash_supports_window
-            and sliding_window is not None
-            and key_states.shape[1] > sliding_window
-        )
+        use_sliding_windows = (
+            sliding_window is not None
+            and key_states.shape[1] > sliding_window
+            and _flash_window_supported(sliding_window)
+        )
src/axolotl/monkeypatch/ring_attn/patch.py (1)

90-99: Use the runtime helper for sliding-window gating.

Prevents accidental enabling of sliding windows on unsupported FA2 versions.

-        use_sliding_windows = (
-            _flash_supports_window
-            and sliding_window is not None
-            and key_states.shape[1] > sliding_window
-        )
+        use_sliding_windows = (
+            sliding_window is not None
+            and key_states.shape[1] > sliding_window
+            and _flash_window_supported(sliding_window)
+        )
🧹 Nitpick comments (1)
src/axolotl/utils/schemas/validation.py (1)

1261-1265: Minor: fix error message formatting.

Unbalanced quotes/backticks around the install hints.

-                    "Please install it with `pip install axolotl[ring-flash-attn] 
-                    "or `pip install ring-flash-attn>=0.1.4`."
+                    "Please install it with `pip install 'axolotl[ring-flash-attn]'` "
+                    "or `pip install 'ring-flash-attn>=0.1.4'`."
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e0a2523 and 50047f5.

📒 Files selected for processing (4)
  • requirements.txt (2 hunks)
  • src/axolotl/monkeypatch/ring_attn/adapters/batch.py (1 hunks)
  • src/axolotl/monkeypatch/ring_attn/patch.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (1 hunks)
🔇 Additional comments (2)
requirements.txt (2)

17-17: Patch bump to transformers==4.55.1 looks safe.

No public API changes are introduced here, and 4.55.1 is a patch release over 4.55.0. Keeping tokenizers>=0.21.1 is consistent with recent transformers constraints.


4-4: Validate bitsandbytes v0.47.0 CUDA wheel support across your CUDA/OS matrix

bitsandbytes 0.47.0 ships prebuilt wheels for CUDA 11.8–12.6 on Linux/Windows x86_64, but CUDA 12.6 binaries are reported to fail loading libbitsandbytes_cuda126.so in some environments (see GH #1703). For stability, target CUDA 11.8–12.5. If you must use 12.6, verify your CI/prod pipelines actually load the 12.6 binaries and consider setting BNB_CUDA_VERSION or adding fallback logic.

• File: requirements.txt (line 4)
• Supported prebuilt CUDA toolkits: 11.8 – 12.6 (Linux/Windows)
• Known issues with CUDA 12.6: missing/failed libbitsandbytes_cuda126.so, CPU-only fallback
• Recommended: restrict your CI/production CUDA matrix to 11.8–12.5, or add explicit tests for 12.6 and document workarounds

Comment on lines +23 to +31
try: # pylint: disable=duplicate-code
from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
try:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
except ImportError:
_flash_supports_window = True

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Avoid aliasing _flash_supports_window_size (likely a function) to a boolean flag.

If _flash_supports_window is absent and _flash_supports_window_size is a callable, aliasing it to _flash_supports_window and later using it in a boolean context will always evaluate truthy, potentially enabling sliding windows even when unsupported. Prefer a helper that queries the transformers module attribute(s) at runtime and returns a boolean.

Apply this refactor to robustly determine window support:

-try:  # pylint: disable=duplicate-code
-    from transformers.modeling_flash_attention_utils import _flash_supports_window
-except ImportError:
-    try:
-        from transformers.modeling_flash_attention_utils import (
-            _flash_supports_window_size as _flash_supports_window,
-        )
-    except ImportError:
-        _flash_supports_window = True
+try:  # pylint: disable=duplicate-code
+    import transformers.modeling_flash_attention_utils as _fau
+except ImportError:
+    _fau = None  # type: ignore
+
+def _flash_window_supported(sliding_window: int | None) -> bool:
+    # Default to True for ring-flash-attn path if utils are unavailable
+    if _fau is None:
+        return True
+    # Prefer _flash_supports_window if present
+    val = getattr(_fau, "_flash_supports_window", None)
+    try:
+        if callable(val):
+            return bool(val() if sliding_window is None else val(sliding_window))
+        if val is not None:
+            return bool(val)
+    except Exception:
+        pass
+    # Fallback to _flash_supports_window_size
+    val = getattr(_fau, "_flash_supports_window_size", None)
+    try:
+        if callable(val):
+            # If no explicit size is given, be conservative and assume True for ring-attn
+            return bool(val(sliding_window if sliding_window is not None else 0))
+        if val is not None:
+            return bool(val)
+    except Exception:
+        pass
+    # Final fallback for ring-flash-attn
+    return True
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
try: # pylint: disable=duplicate-code
from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
try:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
except ImportError:
_flash_supports_window = True
try: # pylint: disable=duplicate-code
import transformers.modeling_flash_attention_utils as _fau
except ImportError:
_fau = None # type: ignore
def _flash_window_supported(sliding_window: int | None) -> bool:
# Default to True for ring-flash-attn path if utils are unavailable
if _fau is None:
return True
# Prefer _flash_supports_window if present
val = getattr(_fau, "_flash_supports_window", None)
try:
if callable(val):
return bool(val() if sliding_window is None else val(sliding_window))
if val is not None:
return bool(val)
except Exception:
pass
# Fallback to _flash_supports_window_size
val = getattr(_fau, "_flash_supports_window_size", None)
try:
if callable(val):
# If no explicit size is given, be conservative and assume True for ring-attn
return bool(val(sliding_window if sliding_window is not None else 0))
if val is not None:
return bool(val)
except Exception:
pass
# Final fallback for ring-flash-attn
return True

Comment on lines +18 to +26
try: # pylint: disable=duplicate-code
from transformers.modeling_flash_attention_utils import _flash_supports_window
except ImportError:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
try:
from transformers.modeling_flash_attention_utils import (
_flash_supports_window_size as _flash_supports_window,
)
except ImportError:
_flash_supports_window = True

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Same issue: aliasing _flash_supports_window_size to a boolean flag is fragile.

Mirror the robust runtime check pattern here to avoid always-True behavior on callable aliases and to accommodate transformers version drift.

-try:  # pylint: disable=duplicate-code
-    from transformers.modeling_flash_attention_utils import _flash_supports_window
-except ImportError:
-    try:
-        from transformers.modeling_flash_attention_utils import (
-            _flash_supports_window_size as _flash_supports_window,
-        )
-    except ImportError:
-        _flash_supports_window = True
+try:  # pylint: disable=duplicate-code
+    import transformers.modeling_flash_attention_utils as _fau
+except ImportError:
+    _fau = None  # type: ignore
+
+def _flash_window_supported(sliding_window: int | None) -> bool:
+    if _fau is None:
+        return True
+    val = getattr(_fau, "_flash_supports_window", None)
+    try:
+        if callable(val):
+            return bool(val() if sliding_window is None else val(sliding_window))
+        if val is not None:
+            return bool(val)
+    except Exception:
+        pass
+    val = getattr(_fau, "_flash_supports_window_size", None)
+    try:
+        if callable(val):
+            return bool(val(sliding_window if sliding_window is not None else 0))
+        if val is not None:
+            return bool(val)
+    except Exception:
+        pass
+    return True

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/ring_attn/patch.py around lines 18 to 26, the current
fallback aliases _flash_supports_window_size to a boolean which can cause
always-True behavior; instead, detect whether the imported symbol is callable
and, if it is not, wrap the boolean in a small runtime-check function so callers
expecting a function that accepts a window_size still get correct behavior.
Modify the import/fallback logic to: try the primary import, then try importing
_flash_supports_window_size and if that import yields a callable keep it as
_flash_supports_window, otherwise define _flash_supports_window = lambda
window_size: bool(_flash_supports_window_size) (or a simple def that returns the
boolean) and finally ensure a sane default function (e.g., returns True/False)
if both imports fail.

Comment on lines +1256 to 1258
transformers.modeling_flash_attention_utils._flash_supports_window = (
True
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Unconditional monkeypatch of a protected transformers internal may have unintended side effects.

Directly setting transformers.modeling_flash_attention_utils._flash_supports_window = True can affect other HF execution paths beyond ring-attn and may diverge from adapter modules that capture this flag at import time. Consider reading the flag dynamically at call sites (in your monkeypatch adapters) rather than forcing a global module-level override here, or set only if the attribute exists and is False to reduce blast radius.

@codecov

codecov Bot commented Aug 13, 2025

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 75.00000% with 4 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
...rc/axolotl/monkeypatch/ring_attn/adapters/batch.py 20.00% 4 Missing ⚠️

📢 Thoughts on this report? Let us know!

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
tests/e2e/patched/test_fsdp2_qlora.py (1)

1-1: Docstring is now slightly misleading; clarify scope.

The test suite no longer covers the Params4bit torch_function patching path. Update the module docstring to reflect that it validates the FSDP2 init-time patches instead.

-"""Integration tests for FSDP2 Params4bit patches."""
+"""Integration tests for FSDP2 init patches (Params4bit integration)."""
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 50047f5 and f585469.

📒 Files selected for processing (3)
  • src/axolotl/loaders/patch_manager.py (0 hunks)
  • src/axolotl/monkeypatch/fsdp2_qlora.py (0 hunks)
  • tests/e2e/patched/test_fsdp2_qlora.py (1 hunks)
💤 Files with no reviewable changes (2)
  • src/axolotl/loaders/patch_manager.py
  • src/axolotl/monkeypatch/fsdp2_qlora.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (10)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: preview
  • GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
  • GitHub Check: pre-commit

Comment on lines 3 to 5
import pytest
import torch
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Harden test against missing/changed Torch internals and ensure test isolation (restore patched methods).

  • Importing FSDPParam from a private Torch module at import time is brittle and can break test collection. Use pytest.importorskip inside the test to skip cleanly if the internal path moves/changes.
  • The test permanently mutates FSDPParam methods, which can leak state to other tests. Use pytest’s monkeypatch to capture/restore originals automatically.
  • Optional: Assert idempotency by re-applying the patches and ensuring the method objects don’t change a second time.

Apply the following changes:

-import pytest
-from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+import pytest
-    @pytest.mark.integration
-    def test_fsdp2_init_patches(self):
-        """Test that all patches can be applied together."""
+    @pytest.mark.integration
+    def test_fsdp2_init_patches(self, monkeypatch):
+        """Verify FSDP2 init patches are applied and idempotent."""
+        # Import Torch internal lazily and skip if unavailable/changed
+        fsdp_mod = pytest.importorskip("torch.distributed.fsdp._fully_shard._fsdp_param")
+        FSDPParam = fsdp_mod.FSDPParam
         from axolotl.monkeypatch.fsdp2_qlora import (
             apply_init_sharded_param_patch,
             apply_init_unsharded_param_patch,
         )
 
-        # pylint: disable=protected-access
-        original_init_sharded = FSDPParam._init_sharded_param
-        original_init_unsharded = FSDPParam.init_unsharded_param
+        # Ensure test isolation by restoring originals after the test
+        # monkeypatch.setattr records the current value and restores it at teardown.
+        monkeypatch.setattr(
+            FSDPParam, "_init_sharded_param", FSDPParam._init_sharded_param, raising=False
+        )
+        monkeypatch.setattr(
+            FSDPParam, "init_unsharded_param", FSDPParam.init_unsharded_param, raising=False
+        )
+
+        # pylint: disable=protected-access
+        original_init_sharded = FSDPParam._init_sharded_param
+        original_init_unsharded = FSDPParam.init_unsharded_param
 
         # Apply patches
         apply_init_sharded_param_patch()
         apply_init_unsharded_param_patch()
 
         assert (
             # pylint: disable=protected-access
             FSDPParam._init_sharded_param
             != original_init_sharded
         ), "_init_sharded_param was not patched"
         assert (
             FSDPParam.init_unsharded_param != original_init_unsharded
         ), "init_unsharded_param was not patched"
+
+        # Optional: patches should be idempotent (re-applying should not re-wrap)
+        patched_init_sharded = FSDPParam._init_sharded_param
+        patched_init_unsharded = FSDPParam.init_unsharded_param
+        apply_init_sharded_param_patch()
+        apply_init_unsharded_param_patch()
+        assert (
+            FSDPParam._init_sharded_param is patched_init_sharded
+        ), "_init_sharded_param patch should be idempotent"
+        assert (
+            FSDPParam.init_unsharded_param is patched_init_unsharded
+        ), "init_unsharded_param patch should be idempotent"

Follow-up:

  • If these patches may be pre-applied by other tests/modules, confirm the patch functions are no-ops when called repeatedly (idempotent). If not, consider adding guards in the patch implementations, or adapt the assertions to tolerate an already-patched state.

Also applies to: 11-33


🏁 Script executed:

#!/bin/bash
set -e

echo "===== Showing tests/e2e/patched/test_fsdp2_qlora.py ====="
sed -n '1,200p' tests/e2e/patched/test_fsdp2_qlora.py || echo "File not found: tests/e2e/patched/test_fsdp2_qlora.py"

echo
echo "===== Locating patch function definitions ====="
rg -n -C3 "def apply_init_sharded_param_patch" -n . || echo "No apply_init_sharded_param_patch definition found"
rg -n -C3 "def apply_init_unsharded_param_patch" -n . || echo "No apply_init_unsharded_param_patch definition found"

Length of output: 2639


🏁 Script executed:

#!/bin/bash
set -e

echo "===== Showing the first 200 lines of src/axolotl/monkeypatch/fsdp2_qlora.py ====="
sed -n '1,200p' src/axolotl/monkeypatch/fsdp2_qlora.py || echo "File not found"

Length of output: 5879


Harden test against Torch internal changes and ensure isolation

  • Avoid importing FSDPParam at module load from a private path; this can break collection if Torch internals move.
  • Lazily import via pytest.importorskip inside the test to skip cleanly on missing/renamed internals.
  • Use the monkeypatch fixture to snapshot and restore FSDPParam methods, preventing cross-test leakage.
  • (Optional) Re-apply the patches to assert idempotency.

Proposed changes:

-import pytest
-from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+import pytest
-    @pytest.mark.integration
-    def test_fsdp2_init_patches(self):
-        """Test that all patches can be applied together."""
+    @pytest.mark.integration
+    def test_fsdp2_init_patches(self, monkeypatch):
+        """Verify FSDPParam patches apply and are idempotent."""
+        # Lazily import Torch internals and skip if unavailable
+        fsdp_mod = pytest.importorskip("torch.distributed.fsdp._fully_shard._fsdp_param")
+        FSDPParam = fsdp_mod.FSDPParam
+
+        # Snapshot originals for automatic restore
+        monkeypatch.setattr(
+            FSDPParam, "_init_sharded_param", FSDPParam._init_sharded_param, raising=False
+        )
+        monkeypatch.setattr(
+            FSDPParam, "init_unsharded_param", FSDPParam.init_unsharded_param, raising=False
+        )
+
+        from axolotl.monkeypatch.fsdp2_qlora import (
+            apply_init_sharded_param_patch,
+            apply_init_unsharded_param_patch,
+        )
+
+        # Capture originals
+        original_init_sharded = FSDPParam._init_sharded_param
+        original_init_unsharded = FSDPParam.init_unsharded_param
+
+        # Apply patches
+        apply_init_sharded_param_patch()
+        apply_init_unsharded_param_patch()
+
+        # Confirm methods changed
+        assert FSDPParam._init_sharded_param is not original_init_sharded, \
+            "_init_sharded_param was not patched"
+        assert FSDPParam.init_unsharded_param is not original_init_unsharded, \
+            "init_unsharded_param was not patched"
+
+        # Optional: ensure idempotency
+        patched_sharded = FSDPParam._init_sharded_param
+        patched_unsharded = FSDPParam.init_unsharded_param
+        apply_init_sharded_param_patch()
+        apply_init_unsharded_param_patch()
+        assert FSDPParam._init_sharded_param is patched_sharded, \
+            "_init_sharded_param patch should be idempotent"
+        assert FSDPParam.init_unsharded_param is patched_unsharded, \
+            "init_unsharded_param patch should be idempotent"

Follow-up: if your patch functions may already have been applied elsewhere, ensure they’re no-ops on repeated calls or adapt the idempotency checks accordingly.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
import pytest
import torch
from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
# --- before (tests/e2e/patched/test_fsdp2_qlora.py) ---
-import pytest
-from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+import pytest
# ... other imports or code ...
class TestFSDP2QLoRA:
- @pytest.mark.integration
- def test_fsdp2_init_patches(self):
- """Test that all patches can be applied together."""
+ @pytest.mark.integration
+ def test_fsdp2_init_patches(self, monkeypatch):
+ """Verify FSDPParam patches apply and are idempotent."""
+ # Lazily import Torch internals and skip if unavailable
+ fsdp_mod = pytest.importorskip("torch.distributed.fsdp._fully_shard._fsdp_param")
+ FSDPParam = fsdp_mod.FSDPParam
+
+ # Snapshot originals for automatic restore
+ monkeypatch.setattr(
+ FSDPParam, "_init_sharded_param", FSDPParam._init_sharded_param, raising=False
+ )
+ monkeypatch.setattr(
+ FSDPParam, "init_unsharded_param", FSDPParam.init_unsharded_param, raising=False
+ )
+
+ from axolotl.monkeypatch.fsdp2_qlora import (
+ apply_init_sharded_param_patch,
+ apply_init_unsharded_param_patch,
+ )
+
+ # Capture originals
+ original_init_sharded = FSDPParam._init_sharded_param
+ original_init_unsharded = FSDPParam.init_unsharded_param
+
+ # Apply patches
+ apply_init_sharded_param_patch()
+ apply_init_unsharded_param_patch()
+
+ # Confirm methods changed
+ assert FSDPParam._init_sharded_param is not original_init_sharded, \
+ "_init_sharded_param was not patched"
+ assert FSDPParam.init_unsharded_param is not original_init_unsharded, \
+ "init_unsharded_param was not patched"
+
+ # Optional: ensure idempotency
+ patched_sharded = FSDPParam._init_sharded_param
+ patched_unsharded = FSDPParam.init_unsharded_param
+ apply_init_sharded_param_patch()
+ apply_init_unsharded_param_patch()
+ assert FSDPParam._init_sharded_param is patched_sharded, \
+ "_init_sharded_param patch should be idempotent"
+ assert FSDPParam.init_unsharded_param is patched_unsharded, \
+ "init_unsharded_param patch should be idempotent"

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
src/axolotl/utils/schemas/validation.py (1)

1253-1276: Narrow the monkeypatch, fix misleading ImportError, and address Ruff B010 (use direct assignment over setattr)

Issues:

  • Line 1258 unconditionally and globally monkeypatches a protected transformers internal, which increases blast radius and may have unintended side effects beyond ring-attn.
  • Lines 1261-1270 set the same attributes twice (direct assignment and setattr) and trigger Ruff B010; direct assignment is preferred.
  • The broad try/except causes any ImportError within Lines 1253-1276 (including a missing is_flash_attn_greater_or_equal in some transformers versions) to surface as a “ring_flash_attn not installed” error, which is misleading.
  • The injection of is_flash_attn_greater_or_equal should be conditional (only if absent), and failures here shouldn’t block ring-flash-attn presence checks.

Apply a safer, narrower patch that:

  • Imports modeling_flash_attention_utils as a module alias and only sets attributes if they’re missing (or False).
  • Avoids redundant setattr calls.
  • Separates ring_flash_attn import into its own try/except so the error message is accurate.
             try:
-                import transformers.modeling_flash_attention_utils
-                from transformers.utils import is_flash_attn_greater_or_equal
-
-                # pylint: disable=protected-access
-                transformers.modeling_flash_attention_utils._flash_supports_window = (
-                    True
-                )
-                setattr(
-                    sys.modules["transformers.modeling_flash_attention_utils"],
-                    "_flash_supports_window",
-                    True,
-                )
-                setattr(
-                    sys.modules["transformers.modeling_flash_attention_utils"],
-                    "_flash_supports_window_size",
-                    True,
-                )
-                setattr(
-                    sys.modules["transformers.modeling_flash_attention_utils"],
-                    "is_flash_attn_greater_or_equal",
-                    is_flash_attn_greater_or_equal,
-                )
-                import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
-            except ImportError as exception:
+                import transformers.modeling_flash_attention_utils as mfa
+            except Exception:
+                mfa = None
+
+            # Only patch if the module was imported; keep scope minimal and idempotent.
+            if mfa is not None:
+                # pylint: disable=protected-access
+                if not hasattr(mfa, "_flash_supports_window"):
+                    mfa._flash_supports_window = True
+                if not hasattr(mfa, "_flash_supports_window_size"):
+                    mfa._flash_supports_window_size = True
+                # Inject utility function only if absent; ignore if not available in this HF version.
+                try:
+                    from transformers.utils import (  # type: ignore
+                        is_flash_attn_greater_or_equal as _is_fa_ge,
+                    )
+                    if not hasattr(mfa, "is_flash_attn_greater_or_equal"):
+                        mfa.is_flash_attn_greater_or_equal = _is_fa_ge
+                except Exception:
+                    pass
+
+            # Validate ring_flash_attn availability with a focused try/except to avoid misleading error messages.
+            try:
+                import ring_flash_attn  # noqa: F401  # pylint:disable=unused-import
+            except ImportError as exception:
                 raise ImportError(
                     "context_parallel_size > 1 but ring_flash_attn is not installed. "
                     "Please install it with `pip install axolotl[ring-flash-attn] "
                     "or `pip install ring-flash-attn>=0.1.4`."
                 ) from exception

Benefits:

  • Reduces global side effects, aligns with earlier feedback to avoid unconditional overrides, and keeps compatibility across transformers versions.
  • Removes redundant setattr calls and satisfies Ruff B010.
  • Produces accurate error messages when ring_flash_attn is missing.

Note: This echoes prior feedback about unconditional monkeypatching protected internals; the new patch constrains it and avoids double-setting.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f585469 and e6d722f.

📒 Files selected for processing (2)
  • docker/Dockerfile-base (1 hunks)
  • src/axolotl/utils/schemas/validation.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/validation.py

1261-1265: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)


1266-1270: Do not call setattr with a constant attribute value. It is not any safer than normal property access.

Replace setattr with assignment

(B010)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: build-base-uv (126, 12.6.3, 3.11, 2.7.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-uv-base)
  • GitHub Check: build-base-uv (128, 12.8.1, 3.11, 2.7.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-uv-base)
  • GitHub Check: build-base-uv (128, 12.8.1, 3.11, 2.8.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-uv-base)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
  • GitHub Check: build-base (128, 12.8.1, 3.11, 2.7.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
  • GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
  • GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
  • GitHub Check: build-base (128, 12.8.1, 3.11, 2.8.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
  • GitHub Check: build-base-uv (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-uv-base)
  • GitHub Check: preview
  • GitHub Check: pre-commit

Comment thread docker/Dockerfile-base
Comment on lines +40 to 42
CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Verification agent

🧩 Analysis chain

Avoid forcing CAUSAL_CONV1D to skip CUDA build in a CUDA base image

Line 40 sets CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE for a CUDA-enabled base image. This likely disables GPU kernels for causal_conv1d and can degrade performance or break GPU paths that expect the CUDA extension. If the intent is to leverage GPU acceleration (which the base image suggests), we should not skip the CUDA build by default.

Proposed fix (remove the env var so GPU extensions are built when available):

-    CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
+    python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \

If you need a CPU-only option for certain builds, make this conditional via a build ARG rather than unconditional in the CUDA image.

Please confirm that we indeed want CPU-only causal_conv1d in this CUDA image. If not, apply the diff above. If yes, consider gating it with a build ARG (e.g., SKIP_CCONV_CUDA_BUILD=true).


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Searching repository for causal_conv1d / CAUSAL_CONV1D_SKIP_CUDA_BUILD ==="
rg -n --hidden --no-ignore -S "CAUSAL_CONV1D_SKIP_CUDA_BUILD|CAUSAL_CONV1D|causal_conv1d" || true

echo
echo "=== Show docker/Dockerfile-base (if present) ==="
if [ -f docker/Dockerfile-base ]; then
  sed -n '1,240p' docker/Dockerfile-base || true
else
  echo "docker/Dockerfile-base not found at that path. Listing Dockerfiles found:"
  rg -n --hidden --no-ignore -S "Dockerfile" || true
fi

echo
echo "=== Search for CUDA-related base images in Dockerfiles / top-level FROM lines ==="
rg -n --hidden --no-ignore -S "FROM .*cuda|FROM .*nvidia|FROM .*pytorch|FROM .*cudnn|FROM .*nvidia/cuda" || true

echo
echo "=== Search CI/workflows for build args or env that mention CAUSAL_CONV1D or skipping CUDA build ==="
rg -n --hidden --no-ignore -S "CAUSAL_CONV1D_SKIP_CUDA_BUILD|CAUSAL_CONV1D|SKIP_CCONV_CUDA_BUILD|SKIP_CCONV|skip_causal|causal_conv1d" || true

Length of output: 5095


Don't force CAUSAL_CONV1D to skip CUDA build in a CUDA base image

This Dockerfile uses an nvidia/cuda base image; unconditionally setting CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE prevents building the CUDA extension (disabling GPU acceleration) and is inconsistent with other Dockerfiles in the repo.

Files to review:

  • docker/Dockerfile-base — line with: CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
  • docker/Dockerfile-base-nightly, docker/Dockerfile-base-next — install causal_conv1d from git without skipping CUDA
  • docker/Dockerfile-uv-base — installs causal_conv1d from git (no skip)
  • setup.py — lists causal_conv1d as a dependency

Proposed minimal change (remove the forced skip):

-    CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
+    python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \

If you need a CPU-only build option, gate it with a build ARG instead of forcing it in the CUDA image. Example approach:

  • Add near the top: ARG SKIP_CCONV_CUDA_BUILD="FALSE"
  • Replace the install line with a conditional in the RUN so the skip is only applied when SKIP_CCONV_CUDA_BUILD is explicitly set to "TRUE".

Please remove the env var or switch to the build-arg gating to keep GPU support available by default.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge
python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
python3 -m pip cache purge

@winglian winglian merged commit 09145de into main Aug 13, 2025
29 of 30 checks passed
@winglian winglian deleted the transformers-upgrade branch August 13, 2025 23:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants