-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
upgrade transformers==4.55.1 and bitsandbytes==0.47.0 #3064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
35d7594
50047f5
f585469
5558249
e6d722f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -20,12 +20,15 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from ring_flash_attn.adapters.hf_adapter import check_params | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: # pylint: disable=duplicate-code | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformers.modeling_flash_attention_utils import _flash_supports_window | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformers.modeling_flash_attention_utils import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _flash_supports_window_size as _flash_supports_window, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformers.modeling_flash_attention_utils import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _flash_supports_window_size as _flash_supports_window, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except ImportError: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _flash_supports_window = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+23
to
+31
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid aliasing _flash_supports_window_size (likely a function) to a boolean flag. If _flash_supports_window is absent and _flash_supports_window_size is a callable, aliasing it to _flash_supports_window and later using it in a boolean context will always evaluate truthy, potentially enabling sliding windows even when unsupported. Prefer a helper that queries the transformers module attribute(s) at runtime and returns a boolean. Apply this refactor to robustly determine window support: -try: # pylint: disable=duplicate-code
- from transformers.modeling_flash_attention_utils import _flash_supports_window
-except ImportError:
- try:
- from transformers.modeling_flash_attention_utils import (
- _flash_supports_window_size as _flash_supports_window,
- )
- except ImportError:
- _flash_supports_window = True
+try: # pylint: disable=duplicate-code
+ import transformers.modeling_flash_attention_utils as _fau
+except ImportError:
+ _fau = None # type: ignore
+
+def _flash_window_supported(sliding_window: int | None) -> bool:
+ # Default to True for ring-flash-attn path if utils are unavailable
+ if _fau is None:
+ return True
+ # Prefer _flash_supports_window if present
+ val = getattr(_fau, "_flash_supports_window", None)
+ try:
+ if callable(val):
+ return bool(val() if sliding_window is None else val(sliding_window))
+ if val is not None:
+ return bool(val)
+ except Exception:
+ pass
+ # Fallback to _flash_supports_window_size
+ val = getattr(_fau, "_flash_supports_window_size", None)
+ try:
+ if callable(val):
+ # If no explicit size is given, be conservative and assume True for ring-attn
+ return bool(val(sliding_window if sliding_window is not None else 0))
+ if val is not None:
+ return bool(val)
+ except Exception:
+ pass
+ # Final fallback for ring-flash-attn
+ return True📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,12 +15,15 @@ | |
| import torch.distributed as dist | ||
| from torch.distributed import DeviceMesh | ||
|
|
||
| try: | ||
| try: # pylint: disable=duplicate-code | ||
| from transformers.modeling_flash_attention_utils import _flash_supports_window | ||
| except ImportError: | ||
| from transformers.modeling_flash_attention_utils import ( | ||
| _flash_supports_window_size as _flash_supports_window, | ||
| ) | ||
| try: | ||
| from transformers.modeling_flash_attention_utils import ( | ||
| _flash_supports_window_size as _flash_supports_window, | ||
| ) | ||
| except ImportError: | ||
| _flash_supports_window = True | ||
|
Comment on lines
+18
to
+26
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same issue: aliasing _flash_supports_window_size to a boolean flag is fragile. Mirror the robust runtime check pattern here to avoid always-True behavior on callable aliases and to accommodate transformers version drift. -try: # pylint: disable=duplicate-code
- from transformers.modeling_flash_attention_utils import _flash_supports_window
-except ImportError:
- try:
- from transformers.modeling_flash_attention_utils import (
- _flash_supports_window_size as _flash_supports_window,
- )
- except ImportError:
- _flash_supports_window = True
+try: # pylint: disable=duplicate-code
+ import transformers.modeling_flash_attention_utils as _fau
+except ImportError:
+ _fau = None # type: ignore
+
+def _flash_window_supported(sliding_window: int | None) -> bool:
+ if _fau is None:
+ return True
+ val = getattr(_fau, "_flash_supports_window", None)
+ try:
+ if callable(val):
+ return bool(val() if sliding_window is None else val(sliding_window))
+ if val is not None:
+ return bool(val)
+ except Exception:
+ pass
+ val = getattr(_fau, "_flash_supports_window_size", None)
+ try:
+ if callable(val):
+ return bool(val(sliding_window if sliding_window is not None else 0))
+ if val is not None:
+ return bool(val)
+ except Exception:
+ pass
+ return True
🤖 Prompt for AI Agents |
||
|
|
||
| from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids | ||
| from axolotl.utils.logging import get_logger | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,126 +1,28 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Integration tests for FSDP Params4bit patches.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Integration tests for FSDP2 Params4bit patches.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from unittest.mock import Mock, patch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import bitsandbytes as bnb | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
3
to
5
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainHarden test against missing/changed Torch internals and ensure test isolation (restore patched methods).
Apply the following changes: -import pytest
-from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+import pytest- @pytest.mark.integration
- def test_fsdp2_init_patches(self):
- """Test that all patches can be applied together."""
+ @pytest.mark.integration
+ def test_fsdp2_init_patches(self, monkeypatch):
+ """Verify FSDP2 init patches are applied and idempotent."""
+ # Import Torch internal lazily and skip if unavailable/changed
+ fsdp_mod = pytest.importorskip("torch.distributed.fsdp._fully_shard._fsdp_param")
+ FSDPParam = fsdp_mod.FSDPParam
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
)
- # pylint: disable=protected-access
- original_init_sharded = FSDPParam._init_sharded_param
- original_init_unsharded = FSDPParam.init_unsharded_param
+ # Ensure test isolation by restoring originals after the test
+ # monkeypatch.setattr records the current value and restores it at teardown.
+ monkeypatch.setattr(
+ FSDPParam, "_init_sharded_param", FSDPParam._init_sharded_param, raising=False
+ )
+ monkeypatch.setattr(
+ FSDPParam, "init_unsharded_param", FSDPParam.init_unsharded_param, raising=False
+ )
+
+ # pylint: disable=protected-access
+ original_init_sharded = FSDPParam._init_sharded_param
+ original_init_unsharded = FSDPParam.init_unsharded_param
# Apply patches
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()
assert (
# pylint: disable=protected-access
FSDPParam._init_sharded_param
!= original_init_sharded
), "_init_sharded_param was not patched"
assert (
FSDPParam.init_unsharded_param != original_init_unsharded
), "init_unsharded_param was not patched"
+
+ # Optional: patches should be idempotent (re-applying should not re-wrap)
+ patched_init_sharded = FSDPParam._init_sharded_param
+ patched_init_unsharded = FSDPParam.init_unsharded_param
+ apply_init_sharded_param_patch()
+ apply_init_unsharded_param_patch()
+ assert (
+ FSDPParam._init_sharded_param is patched_init_sharded
+ ), "_init_sharded_param patch should be idempotent"
+ assert (
+ FSDPParam.init_unsharded_param is patched_init_unsharded
+ ), "init_unsharded_param patch should be idempotent"Follow-up:
Also applies to: 11-33 🏁 Script executed: #!/bin/bash
set -e
echo "===== Showing tests/e2e/patched/test_fsdp2_qlora.py ====="
sed -n '1,200p' tests/e2e/patched/test_fsdp2_qlora.py || echo "File not found: tests/e2e/patched/test_fsdp2_qlora.py"
echo
echo "===== Locating patch function definitions ====="
rg -n -C3 "def apply_init_sharded_param_patch" -n . || echo "No apply_init_sharded_param_patch definition found"
rg -n -C3 "def apply_init_unsharded_param_patch" -n . || echo "No apply_init_unsharded_param_patch definition found"Length of output: 2639 🏁 Script executed: #!/bin/bash
set -e
echo "===== Showing the first 200 lines of src/axolotl/monkeypatch/fsdp2_qlora.py ====="
sed -n '1,200p' src/axolotl/monkeypatch/fsdp2_qlora.py || echo "File not found"Length of output: 5879 Harden test against Torch internal changes and ensure isolation
Proposed changes: -import pytest
-from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+import pytest- @pytest.mark.integration
- def test_fsdp2_init_patches(self):
- """Test that all patches can be applied together."""
+ @pytest.mark.integration
+ def test_fsdp2_init_patches(self, monkeypatch):
+ """Verify FSDPParam patches apply and are idempotent."""
+ # Lazily import Torch internals and skip if unavailable
+ fsdp_mod = pytest.importorskip("torch.distributed.fsdp._fully_shard._fsdp_param")
+ FSDPParam = fsdp_mod.FSDPParam
+
+ # Snapshot originals for automatic restore
+ monkeypatch.setattr(
+ FSDPParam, "_init_sharded_param", FSDPParam._init_sharded_param, raising=False
+ )
+ monkeypatch.setattr(
+ FSDPParam, "init_unsharded_param", FSDPParam.init_unsharded_param, raising=False
+ )
+
+ from axolotl.monkeypatch.fsdp2_qlora import (
+ apply_init_sharded_param_patch,
+ apply_init_unsharded_param_patch,
+ )
+
+ # Capture originals
+ original_init_sharded = FSDPParam._init_sharded_param
+ original_init_unsharded = FSDPParam.init_unsharded_param
+
+ # Apply patches
+ apply_init_sharded_param_patch()
+ apply_init_unsharded_param_patch()
+
+ # Confirm methods changed
+ assert FSDPParam._init_sharded_param is not original_init_sharded, \
+ "_init_sharded_param was not patched"
+ assert FSDPParam.init_unsharded_param is not original_init_unsharded, \
+ "init_unsharded_param was not patched"
+
+ # Optional: ensure idempotency
+ patched_sharded = FSDPParam._init_sharded_param
+ patched_unsharded = FSDPParam.init_unsharded_param
+ apply_init_sharded_param_patch()
+ apply_init_unsharded_param_patch()
+ assert FSDPParam._init_sharded_param is patched_sharded, \
+ "_init_sharded_param patch should be idempotent"
+ assert FSDPParam.init_unsharded_param is patched_unsharded, \
+ "init_unsharded_param patch should be idempotent"Follow-up: if your patch functions may already have been applied elsewhere, ensure they’re no-ops on repeated calls or adapt the idempotency checks accordingly. 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from axolotl.monkeypatch.fsdp2_qlora import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| apply_bnb_torch_function_patch, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| patched_torch_function, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @pytest.fixture | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def mock_params4bit(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Create a mock Params4bit instance with test attributes.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_instance = Mock() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_instance.requires_grad = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_instance.quant_state = "test_state" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_instance.blocksize = 128 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_instance.compress_statistics = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_instance.quant_type = "fp4" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_instance.quant_storage = "test_storage" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_instance.module = "test_module" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_instance.bnb_quantized = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return mock_instance | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class TestBnbTorchFunctionPatch: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Test the Params4bit.__torch_function__ patch.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_apply_patch(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Test that the patch can be applied.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with patch("bitsandbytes.nn.modules.Params4bit") as mock_cls: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| apply_bnb_torch_function_patch() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert hasattr(mock_cls, "__torch_function__") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert isinstance(mock_cls.__torch_function__, classmethod) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # pylint: disable=redefined-outer-name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_torch_chunk_preserves_attributes(self, mock_params4bit): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Test that torch.chunk preserves Params4bit attributes.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_cls = Mock() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| chunks = (torch.tensor([1, 2]), torch.tensor([3, 4])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with patch("torch.nn.Parameter.__torch_function__", return_value=chunks): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result = patched_torch_function( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_cls, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch.chunk, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| (type(mock_params4bit),), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args=(mock_params4bit, 2), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert isinstance(result, tuple) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert len(result) == 2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Check that Params4bit constructor was called with preserved attributes | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert mock_cls.call_count == 2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for call in mock_cls.call_args_list: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| kwargs = call[1] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert kwargs["requires_grad"] == mock_params4bit.requires_grad | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert kwargs["quant_state"] == mock_params4bit.quant_state | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert kwargs["blocksize"] == mock_params4bit.blocksize | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # pylint: disable=redefined-outer-name | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_other_functions_fallback(self, mock_params4bit): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Test that non-chunk/split functions use Parameter fallback.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_cls = Mock() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fallback_result = torch.tensor([5, 6, 7]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| with patch( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "torch.nn.Parameter.__torch_function__", return_value=fallback_result | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) as mock_fallback: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| result = patched_torch_function( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_cls, torch.add, (type(mock_params4bit),), args=(mock_params4bit, 1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Should call Parameter.__torch_function__ and return its result | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_fallback.assert_called_once() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert result is fallback_result | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mock_cls.assert_not_called() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class TestFSDPPatchIntegration: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Test FSDP patch integration.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| @pytest.mark.integration | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_all_patches_together(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def test_fsdp2_init_patches(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Test that all patches can be applied together.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from axolotl.monkeypatch.fsdp2_qlora import ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| apply_init_sharded_param_patch, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| apply_init_unsharded_param_patch, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Store original methods before patching | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| original_torch_function = getattr( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bnb.nn.modules.Params4bit, "__torch_function__", None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # pylint: disable=protected-access | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| original_init_sharded = FSDPParam._init_sharded_param | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| original_init_unsharded = FSDPParam.init_unsharded_param | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Apply patches | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| apply_bnb_torch_function_patch() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| apply_init_sharded_param_patch() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| apply_init_unsharded_param_patch() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Verify patches were applied | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| current_torch_function = getattr( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| bnb.nn.modules.Params4bit, "__torch_function__", None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if original_torch_function is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| current_torch_function != original_torch_function | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), "Params4bit.__torch_function__ was not patched" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| current_torch_function is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ), "Params4bit.__torch_function__ was not added" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Check that FSDP methods were patched | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # pylint: disable=protected-access | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| FSDPParam._init_sharded_param | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Verification agent
🧩 Analysis chain
Avoid forcing CAUSAL_CONV1D to skip CUDA build in a CUDA base image
Line 40 sets CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE for a CUDA-enabled base image. This likely disables GPU kernels for causal_conv1d and can degrade performance or break GPU paths that expect the CUDA extension. If the intent is to leverage GPU acceleration (which the base image suggests), we should not skip the CUDA build by default.
Proposed fix (remove the env var so GPU extensions are built when available):
If you need a CPU-only option for certain builds, make this conditional via a build ARG rather than unconditional in the CUDA image.
Please confirm that we indeed want CPU-only causal_conv1d in this CUDA image. If not, apply the diff above. If yes, consider gating it with a build ARG (e.g., SKIP_CCONV_CUDA_BUILD=true).
🏁 Script executed:
Length of output: 5095
Don't force CAUSAL_CONV1D to skip CUDA build in a CUDA base image
This Dockerfile uses an nvidia/cuda base image; unconditionally setting CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE prevents building the CUDA extension (disabling GPU acceleration) and is inconsistent with other Dockerfiles in the repo.
Files to review:
CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \Proposed minimal change (remove the forced skip):
If you need a CPU-only build option, gate it with a build ARG instead of forcing it in the CUDA image. Example approach:
ARG SKIP_CCONV_CUDA_BUILD="FALSE"SKIP_CCONV_CUDA_BUILDis explicitly set to "TRUE".Please remove the env var or switch to the build-arg gating to keep GPU support available by default.
📝 Committable suggestion