upgrade transformers==4.55.1 and bitsandbytes==0.47.0#3064
Conversation
📝 WalkthroughWalkthroughBumps two dependency pins in requirements.txt; extends and hardens flash-attention import fallbacks and default behavior; removes BitsAndBytes torch_function patching and replaces it with FSDPParam init-time patches; removes invocation of the BnB torch-function patch; narrows related end-to-end tests accordingly. No public API signatures changed. Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~35 minutes Possibly related PRs
Suggested labels
Suggested reviewers
✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR/Issue comments)Type Other keywords and placeholders
Status, Documentation and Community
|
|
📖 Documentation Preview: https://689cf6e1e1a684da77c5820c--resonant-treacle-0fd729.netlify.app Deployed on Netlify from commit e6d722f |
|
converting this PR from a draft to ready seems to still skip the multigpu CI. manual run here: https://github.com/axolotl-ai-cloud/axolotl/actions/runs/16943460540 |
There was a problem hiding this comment.
Actionable comments posted: 3
🔭 Outside diff range comments (2)
src/axolotl/monkeypatch/ring_attn/adapters/batch.py (1)
121-129: Use the runtime helper to decide sliding window support.This ensures correct behavior across transformers versions (attr vs func) and picks up any runtime overrides.
- use_sliding_windows = ( - _flash_supports_window - and sliding_window is not None - and key_states.shape[1] > sliding_window - ) + use_sliding_windows = ( + sliding_window is not None + and key_states.shape[1] > sliding_window + and _flash_window_supported(sliding_window) + )src/axolotl/monkeypatch/ring_attn/patch.py (1)
90-99: Use the runtime helper for sliding-window gating.Prevents accidental enabling of sliding windows on unsupported FA2 versions.
- use_sliding_windows = ( - _flash_supports_window - and sliding_window is not None - and key_states.shape[1] > sliding_window - ) + use_sliding_windows = ( + sliding_window is not None + and key_states.shape[1] > sliding_window + and _flash_window_supported(sliding_window) + )
🧹 Nitpick comments (1)
src/axolotl/utils/schemas/validation.py (1)
1261-1265: Minor: fix error message formatting.Unbalanced quotes/backticks around the install hints.
- "Please install it with `pip install axolotl[ring-flash-attn] - "or `pip install ring-flash-attn>=0.1.4`." + "Please install it with `pip install 'axolotl[ring-flash-attn]'` " + "or `pip install 'ring-flash-attn>=0.1.4'`."
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
requirements.txt(2 hunks)src/axolotl/monkeypatch/ring_attn/adapters/batch.py(1 hunks)src/axolotl/monkeypatch/ring_attn/patch.py(1 hunks)src/axolotl/utils/schemas/validation.py(1 hunks)
🔇 Additional comments (2)
requirements.txt (2)
17-17: Patch bump to transformers==4.55.1 looks safe.No public API changes are introduced here, and 4.55.1 is a patch release over 4.55.0. Keeping tokenizers>=0.21.1 is consistent with recent transformers constraints.
4-4: Validate bitsandbytes v0.47.0 CUDA wheel support across your CUDA/OS matrixbitsandbytes 0.47.0 ships prebuilt wheels for CUDA 11.8–12.6 on Linux/Windows x86_64, but CUDA 12.6 binaries are reported to fail loading
libbitsandbytes_cuda126.soin some environments (see GH #1703). For stability, target CUDA 11.8–12.5. If you must use 12.6, verify your CI/prod pipelines actually load the 12.6 binaries and consider settingBNB_CUDA_VERSIONor adding fallback logic.• File: requirements.txt (line 4)
• Supported prebuilt CUDA toolkits: 11.8 – 12.6 (Linux/Windows)
• Known issues with CUDA 12.6: missing/failedlibbitsandbytes_cuda126.so, CPU-only fallback
• Recommended: restrict your CI/production CUDA matrix to 11.8–12.5, or add explicit tests for 12.6 and document workarounds
| try: # pylint: disable=duplicate-code | ||
| from transformers.modeling_flash_attention_utils import _flash_supports_window | ||
| except ImportError: | ||
| from transformers.modeling_flash_attention_utils import ( | ||
| _flash_supports_window_size as _flash_supports_window, | ||
| ) | ||
| try: | ||
| from transformers.modeling_flash_attention_utils import ( | ||
| _flash_supports_window_size as _flash_supports_window, | ||
| ) | ||
| except ImportError: | ||
| _flash_supports_window = True |
There was a problem hiding this comment.
Avoid aliasing _flash_supports_window_size (likely a function) to a boolean flag.
If _flash_supports_window is absent and _flash_supports_window_size is a callable, aliasing it to _flash_supports_window and later using it in a boolean context will always evaluate truthy, potentially enabling sliding windows even when unsupported. Prefer a helper that queries the transformers module attribute(s) at runtime and returns a boolean.
Apply this refactor to robustly determine window support:
-try: # pylint: disable=duplicate-code
- from transformers.modeling_flash_attention_utils import _flash_supports_window
-except ImportError:
- try:
- from transformers.modeling_flash_attention_utils import (
- _flash_supports_window_size as _flash_supports_window,
- )
- except ImportError:
- _flash_supports_window = True
+try: # pylint: disable=duplicate-code
+ import transformers.modeling_flash_attention_utils as _fau
+except ImportError:
+ _fau = None # type: ignore
+
+def _flash_window_supported(sliding_window: int | None) -> bool:
+ # Default to True for ring-flash-attn path if utils are unavailable
+ if _fau is None:
+ return True
+ # Prefer _flash_supports_window if present
+ val = getattr(_fau, "_flash_supports_window", None)
+ try:
+ if callable(val):
+ return bool(val() if sliding_window is None else val(sliding_window))
+ if val is not None:
+ return bool(val)
+ except Exception:
+ pass
+ # Fallback to _flash_supports_window_size
+ val = getattr(_fau, "_flash_supports_window_size", None)
+ try:
+ if callable(val):
+ # If no explicit size is given, be conservative and assume True for ring-attn
+ return bool(val(sliding_window if sliding_window is not None else 0))
+ if val is not None:
+ return bool(val)
+ except Exception:
+ pass
+ # Final fallback for ring-flash-attn
+ return True📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| try: # pylint: disable=duplicate-code | |
| from transformers.modeling_flash_attention_utils import _flash_supports_window | |
| except ImportError: | |
| from transformers.modeling_flash_attention_utils import ( | |
| _flash_supports_window_size as _flash_supports_window, | |
| ) | |
| try: | |
| from transformers.modeling_flash_attention_utils import ( | |
| _flash_supports_window_size as _flash_supports_window, | |
| ) | |
| except ImportError: | |
| _flash_supports_window = True | |
| try: # pylint: disable=duplicate-code | |
| import transformers.modeling_flash_attention_utils as _fau | |
| except ImportError: | |
| _fau = None # type: ignore | |
| def _flash_window_supported(sliding_window: int | None) -> bool: | |
| # Default to True for ring-flash-attn path if utils are unavailable | |
| if _fau is None: | |
| return True | |
| # Prefer _flash_supports_window if present | |
| val = getattr(_fau, "_flash_supports_window", None) | |
| try: | |
| if callable(val): | |
| return bool(val() if sliding_window is None else val(sliding_window)) | |
| if val is not None: | |
| return bool(val) | |
| except Exception: | |
| pass | |
| # Fallback to _flash_supports_window_size | |
| val = getattr(_fau, "_flash_supports_window_size", None) | |
| try: | |
| if callable(val): | |
| # If no explicit size is given, be conservative and assume True for ring-attn | |
| return bool(val(sliding_window if sliding_window is not None else 0)) | |
| if val is not None: | |
| return bool(val) | |
| except Exception: | |
| pass | |
| # Final fallback for ring-flash-attn | |
| return True |
| try: # pylint: disable=duplicate-code | ||
| from transformers.modeling_flash_attention_utils import _flash_supports_window | ||
| except ImportError: | ||
| from transformers.modeling_flash_attention_utils import ( | ||
| _flash_supports_window_size as _flash_supports_window, | ||
| ) | ||
| try: | ||
| from transformers.modeling_flash_attention_utils import ( | ||
| _flash_supports_window_size as _flash_supports_window, | ||
| ) | ||
| except ImportError: | ||
| _flash_supports_window = True |
There was a problem hiding this comment.
Same issue: aliasing _flash_supports_window_size to a boolean flag is fragile.
Mirror the robust runtime check pattern here to avoid always-True behavior on callable aliases and to accommodate transformers version drift.
-try: # pylint: disable=duplicate-code
- from transformers.modeling_flash_attention_utils import _flash_supports_window
-except ImportError:
- try:
- from transformers.modeling_flash_attention_utils import (
- _flash_supports_window_size as _flash_supports_window,
- )
- except ImportError:
- _flash_supports_window = True
+try: # pylint: disable=duplicate-code
+ import transformers.modeling_flash_attention_utils as _fau
+except ImportError:
+ _fau = None # type: ignore
+
+def _flash_window_supported(sliding_window: int | None) -> bool:
+ if _fau is None:
+ return True
+ val = getattr(_fau, "_flash_supports_window", None)
+ try:
+ if callable(val):
+ return bool(val() if sliding_window is None else val(sliding_window))
+ if val is not None:
+ return bool(val)
+ except Exception:
+ pass
+ val = getattr(_fau, "_flash_supports_window_size", None)
+ try:
+ if callable(val):
+ return bool(val(sliding_window if sliding_window is not None else 0))
+ if val is not None:
+ return bool(val)
+ except Exception:
+ pass
+ return TrueCommittable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/ring_attn/patch.py around lines 18 to 26, the current
fallback aliases _flash_supports_window_size to a boolean which can cause
always-True behavior; instead, detect whether the imported symbol is callable
and, if it is not, wrap the boolean in a small runtime-check function so callers
expecting a function that accepts a window_size still get correct behavior.
Modify the import/fallback logic to: try the primary import, then try importing
_flash_supports_window_size and if that import yields a callable keep it as
_flash_supports_window, otherwise define _flash_supports_window = lambda
window_size: bool(_flash_supports_window_size) (or a simple def that returns the
boolean) and finally ensure a sane default function (e.g., returns True/False)
if both imports fail.
| transformers.modeling_flash_attention_utils._flash_supports_window = ( | ||
| True | ||
| ) |
There was a problem hiding this comment.
🛠️ Refactor suggestion
Unconditional monkeypatch of a protected transformers internal may have unintended side effects.
Directly setting transformers.modeling_flash_attention_utils._flash_supports_window = True can affect other HF execution paths beyond ring-attn and may diverge from adapter modules that capture this flag at import time. Consider reading the flag dynamically at call sites (in your monkeypatch adapters) rather than forcing a global module-level override here, or set only if the attribute exists and is False to reduce blast radius.
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/e2e/patched/test_fsdp2_qlora.py (1)
1-1: Docstring is now slightly misleading; clarify scope.The test suite no longer covers the Params4bit torch_function patching path. Update the module docstring to reflect that it validates the FSDP2 init-time patches instead.
-"""Integration tests for FSDP2 Params4bit patches.""" +"""Integration tests for FSDP2 init patches (Params4bit integration)."""
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
src/axolotl/loaders/patch_manager.py(0 hunks)src/axolotl/monkeypatch/fsdp2_qlora.py(0 hunks)tests/e2e/patched/test_fsdp2_qlora.py(1 hunks)
💤 Files with no reviewable changes (2)
- src/axolotl/loaders/patch_manager.py
- src/axolotl/monkeypatch/fsdp2_qlora.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (10)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: pre-commit
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: preview
- GitHub Check: test-axolotl-multigpu (126, 12.6.3, 3.11, 2.7.1, vllm, 2, true)
- GitHub Check: pre-commit
| import pytest | ||
| import torch | ||
| from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam | ||
|
|
There was a problem hiding this comment.
💡 Verification agent
🧩 Analysis chain
Harden test against missing/changed Torch internals and ensure test isolation (restore patched methods).
- Importing FSDPParam from a private Torch module at import time is brittle and can break test collection. Use pytest.importorskip inside the test to skip cleanly if the internal path moves/changes.
- The test permanently mutates FSDPParam methods, which can leak state to other tests. Use pytest’s monkeypatch to capture/restore originals automatically.
- Optional: Assert idempotency by re-applying the patches and ensuring the method objects don’t change a second time.
Apply the following changes:
-import pytest
-from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+import pytest- @pytest.mark.integration
- def test_fsdp2_init_patches(self):
- """Test that all patches can be applied together."""
+ @pytest.mark.integration
+ def test_fsdp2_init_patches(self, monkeypatch):
+ """Verify FSDP2 init patches are applied and idempotent."""
+ # Import Torch internal lazily and skip if unavailable/changed
+ fsdp_mod = pytest.importorskip("torch.distributed.fsdp._fully_shard._fsdp_param")
+ FSDPParam = fsdp_mod.FSDPParam
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
)
- # pylint: disable=protected-access
- original_init_sharded = FSDPParam._init_sharded_param
- original_init_unsharded = FSDPParam.init_unsharded_param
+ # Ensure test isolation by restoring originals after the test
+ # monkeypatch.setattr records the current value and restores it at teardown.
+ monkeypatch.setattr(
+ FSDPParam, "_init_sharded_param", FSDPParam._init_sharded_param, raising=False
+ )
+ monkeypatch.setattr(
+ FSDPParam, "init_unsharded_param", FSDPParam.init_unsharded_param, raising=False
+ )
+
+ # pylint: disable=protected-access
+ original_init_sharded = FSDPParam._init_sharded_param
+ original_init_unsharded = FSDPParam.init_unsharded_param
# Apply patches
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
assert (
# pylint: disable=protected-access
FSDPParam._init_sharded_param
!= original_init_sharded
), "_init_sharded_param was not patched"
assert (
FSDPParam.init_unsharded_param != original_init_unsharded
), "init_unsharded_param was not patched"
+
+ # Optional: patches should be idempotent (re-applying should not re-wrap)
+ patched_init_sharded = FSDPParam._init_sharded_param
+ patched_init_unsharded = FSDPParam.init_unsharded_param
+ apply_init_sharded_param_patch()
+ apply_init_unsharded_param_patch()
+ assert (
+ FSDPParam._init_sharded_param is patched_init_sharded
+ ), "_init_sharded_param patch should be idempotent"
+ assert (
+ FSDPParam.init_unsharded_param is patched_init_unsharded
+ ), "init_unsharded_param patch should be idempotent"Follow-up:
- If these patches may be pre-applied by other tests/modules, confirm the patch functions are no-ops when called repeatedly (idempotent). If not, consider adding guards in the patch implementations, or adapt the assertions to tolerate an already-patched state.
Also applies to: 11-33
🏁 Script executed:
#!/bin/bash
set -e
echo "===== Showing tests/e2e/patched/test_fsdp2_qlora.py ====="
sed -n '1,200p' tests/e2e/patched/test_fsdp2_qlora.py || echo "File not found: tests/e2e/patched/test_fsdp2_qlora.py"
echo
echo "===== Locating patch function definitions ====="
rg -n -C3 "def apply_init_sharded_param_patch" -n . || echo "No apply_init_sharded_param_patch definition found"
rg -n -C3 "def apply_init_unsharded_param_patch" -n . || echo "No apply_init_unsharded_param_patch definition found"Length of output: 2639
🏁 Script executed:
#!/bin/bash
set -e
echo "===== Showing the first 200 lines of src/axolotl/monkeypatch/fsdp2_qlora.py ====="
sed -n '1,200p' src/axolotl/monkeypatch/fsdp2_qlora.py || echo "File not found"Length of output: 5879
Harden test against Torch internal changes and ensure isolation
- Avoid importing
FSDPParamat module load from a private path; this can break collection if Torch internals move. - Lazily import via
pytest.importorskipinside the test to skip cleanly on missing/renamed internals. - Use the
monkeypatchfixture to snapshot and restoreFSDPParammethods, preventing cross-test leakage. - (Optional) Re-apply the patches to assert idempotency.
Proposed changes:
-import pytest
-from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+import pytest- @pytest.mark.integration
- def test_fsdp2_init_patches(self):
- """Test that all patches can be applied together."""
+ @pytest.mark.integration
+ def test_fsdp2_init_patches(self, monkeypatch):
+ """Verify FSDPParam patches apply and are idempotent."""
+ # Lazily import Torch internals and skip if unavailable
+ fsdp_mod = pytest.importorskip("torch.distributed.fsdp._fully_shard._fsdp_param")
+ FSDPParam = fsdp_mod.FSDPParam
+
+ # Snapshot originals for automatic restore
+ monkeypatch.setattr(
+ FSDPParam, "_init_sharded_param", FSDPParam._init_sharded_param, raising=False
+ )
+ monkeypatch.setattr(
+ FSDPParam, "init_unsharded_param", FSDPParam.init_unsharded_param, raising=False
+ )
+
+ from axolotl.monkeypatch.fsdp2_qlora import (
+ apply_init_sharded_param_patch,
+ apply_init_unsharded_param_patch,
+ )
+
+ # Capture originals
+ original_init_sharded = FSDPParam._init_sharded_param
+ original_init_unsharded = FSDPParam.init_unsharded_param
+
+ # Apply patches
+ apply_init_sharded_param_patch()
+ apply_init_unsharded_param_patch()
+
+ # Confirm methods changed
+ assert FSDPParam._init_sharded_param is not original_init_sharded, \
+ "_init_sharded_param was not patched"
+ assert FSDPParam.init_unsharded_param is not original_init_unsharded, \
+ "init_unsharded_param was not patched"
+
+ # Optional: ensure idempotency
+ patched_sharded = FSDPParam._init_sharded_param
+ patched_unsharded = FSDPParam.init_unsharded_param
+ apply_init_sharded_param_patch()
+ apply_init_unsharded_param_patch()
+ assert FSDPParam._init_sharded_param is patched_sharded, \
+ "_init_sharded_param patch should be idempotent"
+ assert FSDPParam.init_unsharded_param is patched_unsharded, \
+ "init_unsharded_param patch should be idempotent"Follow-up: if your patch functions may already have been applied elsewhere, ensure they’re no-ops on repeated calls or adapt the idempotency checks accordingly.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| import pytest | |
| import torch | |
| from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam | |
| # --- before (tests/e2e/patched/test_fsdp2_qlora.py) --- | |
| -import pytest | |
| -from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam | |
| +import pytest | |
| # ... other imports or code ... | |
| class TestFSDP2QLoRA: | |
| - @pytest.mark.integration | |
| - def test_fsdp2_init_patches(self): | |
| - """Test that all patches can be applied together.""" | |
| + @pytest.mark.integration | |
| + def test_fsdp2_init_patches(self, monkeypatch): | |
| + """Verify FSDPParam patches apply and are idempotent.""" | |
| + # Lazily import Torch internals and skip if unavailable | |
| + fsdp_mod = pytest.importorskip("torch.distributed.fsdp._fully_shard._fsdp_param") | |
| + FSDPParam = fsdp_mod.FSDPParam | |
| + | |
| + # Snapshot originals for automatic restore | |
| + monkeypatch.setattr( | |
| + FSDPParam, "_init_sharded_param", FSDPParam._init_sharded_param, raising=False | |
| + ) | |
| + monkeypatch.setattr( | |
| + FSDPParam, "init_unsharded_param", FSDPParam.init_unsharded_param, raising=False | |
| + ) | |
| + | |
| + from axolotl.monkeypatch.fsdp2_qlora import ( | |
| + apply_init_sharded_param_patch, | |
| + apply_init_unsharded_param_patch, | |
| + ) | |
| + | |
| + # Capture originals | |
| + original_init_sharded = FSDPParam._init_sharded_param | |
| + original_init_unsharded = FSDPParam.init_unsharded_param | |
| + | |
| + # Apply patches | |
| + apply_init_sharded_param_patch() | |
| + apply_init_unsharded_param_patch() | |
| + | |
| + # Confirm methods changed | |
| + assert FSDPParam._init_sharded_param is not original_init_sharded, \ | |
| + "_init_sharded_param was not patched" | |
| + assert FSDPParam.init_unsharded_param is not original_init_unsharded, \ | |
| + "init_unsharded_param was not patched" | |
| + | |
| + # Optional: ensure idempotency | |
| + patched_sharded = FSDPParam._init_sharded_param | |
| + patched_unsharded = FSDPParam.init_unsharded_param | |
| + apply_init_sharded_param_patch() | |
| + apply_init_unsharded_param_patch() | |
| + assert FSDPParam._init_sharded_param is patched_sharded, \ | |
| + "_init_sharded_param patch should be idempotent" | |
| + assert FSDPParam.init_unsharded_param is patched_unsharded, \ | |
| + "init_unsharded_param patch should be idempotent" |
There was a problem hiding this comment.
Actionable comments posted: 1
♻️ Duplicate comments (1)
src/axolotl/utils/schemas/validation.py (1)
1253-1276: Narrow the monkeypatch, fix misleading ImportError, and address Ruff B010 (use direct assignment over setattr)Issues:
- Line 1258 unconditionally and globally monkeypatches a protected transformers internal, which increases blast radius and may have unintended side effects beyond ring-attn.
- Lines 1261-1270 set the same attributes twice (direct assignment and setattr) and trigger Ruff B010; direct assignment is preferred.
- The broad try/except causes any ImportError within Lines 1253-1276 (including a missing is_flash_attn_greater_or_equal in some transformers versions) to surface as a “ring_flash_attn not installed” error, which is misleading.
- The injection of is_flash_attn_greater_or_equal should be conditional (only if absent), and failures here shouldn’t block ring-flash-attn presence checks.
Apply a safer, narrower patch that:
- Imports modeling_flash_attention_utils as a module alias and only sets attributes if they’re missing (or False).
- Avoids redundant setattr calls.
- Separates ring_flash_attn import into its own try/except so the error message is accurate.
try: - import transformers.modeling_flash_attention_utils - from transformers.utils import is_flash_attn_greater_or_equal - - # pylint: disable=protected-access - transformers.modeling_flash_attention_utils._flash_supports_window = ( - True - ) - setattr( - sys.modules["transformers.modeling_flash_attention_utils"], - "_flash_supports_window", - True, - ) - setattr( - sys.modules["transformers.modeling_flash_attention_utils"], - "_flash_supports_window_size", - True, - ) - setattr( - sys.modules["transformers.modeling_flash_attention_utils"], - "is_flash_attn_greater_or_equal", - is_flash_attn_greater_or_equal, - ) - import ring_flash_attn # noqa: F401 # pylint:disable=unused-import - except ImportError as exception: + import transformers.modeling_flash_attention_utils as mfa + except Exception: + mfa = None + + # Only patch if the module was imported; keep scope minimal and idempotent. + if mfa is not None: + # pylint: disable=protected-access + if not hasattr(mfa, "_flash_supports_window"): + mfa._flash_supports_window = True + if not hasattr(mfa, "_flash_supports_window_size"): + mfa._flash_supports_window_size = True + # Inject utility function only if absent; ignore if not available in this HF version. + try: + from transformers.utils import ( # type: ignore + is_flash_attn_greater_or_equal as _is_fa_ge, + ) + if not hasattr(mfa, "is_flash_attn_greater_or_equal"): + mfa.is_flash_attn_greater_or_equal = _is_fa_ge + except Exception: + pass + + # Validate ring_flash_attn availability with a focused try/except to avoid misleading error messages. + try: + import ring_flash_attn # noqa: F401 # pylint:disable=unused-import + except ImportError as exception: raise ImportError( "context_parallel_size > 1 but ring_flash_attn is not installed. " "Please install it with `pip install axolotl[ring-flash-attn] " "or `pip install ring-flash-attn>=0.1.4`." ) from exceptionBenefits:
- Reduces global side effects, aligns with earlier feedback to avoid unconditional overrides, and keeps compatibility across transformers versions.
- Removes redundant setattr calls and satisfies Ruff B010.
- Produces accurate error messages when ring_flash_attn is missing.
Note: This echoes prior feedback about unconditional monkeypatching protected internals; the new patch constrains it and avoids double-setting.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
docker/Dockerfile-base(1 hunks)src/axolotl/utils/schemas/validation.py(2 hunks)
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/validation.py
1261-1265: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
1266-1270: Do not call setattr with a constant attribute value. It is not any safer than normal property access.
Replace setattr with assignment
(B010)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: pre-commit
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: build-base-uv (126, 12.6.3, 3.11, 2.7.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-uv-base)
- GitHub Check: build-base-uv (128, 12.8.1, 3.11, 2.7.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-uv-base)
- GitHub Check: build-base-uv (128, 12.8.1, 3.11, 2.8.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-uv-base)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
- GitHub Check: build-base (128, 12.8.1, 3.11, 2.7.1, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.7.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
- GitHub Check: build-base (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
- GitHub Check: build-base (124, 12.4.1, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
- GitHub Check: build-base (128, 12.8.1, 3.11, 2.8.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-base)
- GitHub Check: build-base-uv (126, 12.6.3, 3.11, 2.6.0, 7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX, Dockerfile-uv-base)
- GitHub Check: preview
- GitHub Check: pre-commit
| CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \ | ||
| python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \ | ||
| python3 -m pip cache purge |
There was a problem hiding this comment.
💡 Verification agent
🧩 Analysis chain
Avoid forcing CAUSAL_CONV1D to skip CUDA build in a CUDA base image
Line 40 sets CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE for a CUDA-enabled base image. This likely disables GPU kernels for causal_conv1d and can degrade performance or break GPU paths that expect the CUDA extension. If the intent is to leverage GPU acceleration (which the base image suggests), we should not skip the CUDA build by default.
Proposed fix (remove the env var so GPU extensions are built when available):
- CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
+ python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \If you need a CPU-only option for certain builds, make this conditional via a build ARG rather than unconditional in the CUDA image.
Please confirm that we indeed want CPU-only causal_conv1d in this CUDA image. If not, apply the diff above. If yes, consider gating it with a build ARG (e.g., SKIP_CCONV_CUDA_BUILD=true).
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Searching repository for causal_conv1d / CAUSAL_CONV1D_SKIP_CUDA_BUILD ==="
rg -n --hidden --no-ignore -S "CAUSAL_CONV1D_SKIP_CUDA_BUILD|CAUSAL_CONV1D|causal_conv1d" || true
echo
echo "=== Show docker/Dockerfile-base (if present) ==="
if [ -f docker/Dockerfile-base ]; then
sed -n '1,240p' docker/Dockerfile-base || true
else
echo "docker/Dockerfile-base not found at that path. Listing Dockerfiles found:"
rg -n --hidden --no-ignore -S "Dockerfile" || true
fi
echo
echo "=== Search for CUDA-related base images in Dockerfiles / top-level FROM lines ==="
rg -n --hidden --no-ignore -S "FROM .*cuda|FROM .*nvidia|FROM .*pytorch|FROM .*cudnn|FROM .*nvidia/cuda" || true
echo
echo "=== Search CI/workflows for build args or env that mention CAUSAL_CONV1D or skipping CUDA build ==="
rg -n --hidden --no-ignore -S "CAUSAL_CONV1D_SKIP_CUDA_BUILD|CAUSAL_CONV1D|SKIP_CCONV_CUDA_BUILD|SKIP_CCONV|skip_causal|causal_conv1d" || trueLength of output: 5095
Don't force CAUSAL_CONV1D to skip CUDA build in a CUDA base image
This Dockerfile uses an nvidia/cuda base image; unconditionally setting CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE prevents building the CUDA extension (disabling GPU acceleration) and is inconsistent with other Dockerfiles in the repo.
Files to review:
- docker/Dockerfile-base — line with:
CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \ - docker/Dockerfile-base-nightly, docker/Dockerfile-base-next — install causal_conv1d from git without skipping CUDA
- docker/Dockerfile-uv-base — installs causal_conv1d from git (no skip)
- setup.py — lists causal_conv1d as a dependency
Proposed minimal change (remove the forced skip):
- CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
+ python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \If you need a CPU-only build option, gate it with a build ARG instead of forcing it in the CUDA image. Example approach:
- Add near the top:
ARG SKIP_CCONV_CUDA_BUILD="FALSE" - Replace the install line with a conditional in the RUN so the skip is only applied when
SKIP_CCONV_CUDA_BUILDis explicitly set to "TRUE".
Please remove the env var or switch to the build-arg gating to keep GPU support available by default.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \ | |
| python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \ | |
| python3 -m pip cache purge | |
| python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \ | |
| python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \ | |
| python3 -m pip cache purge |
Summary by CodeRabbit
New Features
Bug Fixes
Chores
Refactor
Tests