Merged
2 changes: 1 addition & 1 deletion docker/Dockerfile-base
@@ -37,7 +37,7 @@ WORKDIR /workspace

 RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
     python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} torchvision --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
-    python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
+    CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
     python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
     python3 -m pip cache purge
Comment on lines +40 to 42

Contributor


💡 Verification agent

🧩 Analysis chain

Avoid forcing CAUSAL_CONV1D to skip CUDA build in a CUDA base image

Line 40 sets CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE for a CUDA-enabled base image. This likely disables GPU kernels for causal_conv1d and can degrade performance or break GPU paths that expect the CUDA extension. If the intent is to leverage GPU acceleration (which the base image suggests), we should not skip the CUDA build by default.

Proposed fix (remove the env var so GPU extensions are built when available):

-    CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
+    python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \

If you need a CPU-only option for certain builds, make this conditional via a build ARG rather than unconditional in the CUDA image.

Please confirm that we indeed want CPU-only causal_conv1d in this CUDA image. If not, apply the diff above. If yes, consider gating it with a build ARG (e.g., SKIP_CCONV_CUDA_BUILD=true).


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Searching repository for causal_conv1d / CAUSAL_CONV1D_SKIP_CUDA_BUILD ==="
rg -n --hidden --no-ignore -S "CAUSAL_CONV1D_SKIP_CUDA_BUILD|CAUSAL_CONV1D|causal_conv1d" || true

echo
echo "=== Show docker/Dockerfile-base (if present) ==="
if [ -f docker/Dockerfile-base ]; then
  sed -n '1,240p' docker/Dockerfile-base || true
else
  echo "docker/Dockerfile-base not found at that path. Listing Dockerfiles found:"
  rg -n --hidden --no-ignore -S "Dockerfile" || true
fi

echo
echo "=== Search for CUDA-related base images in Dockerfiles / top-level FROM lines ==="
rg -n --hidden --no-ignore -S "FROM .*cuda|FROM .*nvidia|FROM .*pytorch|FROM .*cudnn|FROM .*nvidia/cuda" || true

echo
echo "=== Search CI/workflows for build args or env that mention CAUSAL_CONV1D or skipping CUDA build ==="
rg -n --hidden --no-ignore -S "CAUSAL_CONV1D_SKIP_CUDA_BUILD|CAUSAL_CONV1D|SKIP_CCONV_CUDA_BUILD|SKIP_CCONV|skip_causal|causal_conv1d" || true

Length of output: 5095


Don't force CAUSAL_CONV1D to skip CUDA build in a CUDA base image

This Dockerfile uses an nvidia/cuda base image; unconditionally setting CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE prevents building the CUDA extension (disabling GPU acceleration) and is inconsistent with other Dockerfiles in the repo.

Files to review:

  • docker/Dockerfile-base — line with: CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
  • docker/Dockerfile-base-nightly, docker/Dockerfile-base-next — install causal_conv1d from git without skipping CUDA
  • docker/Dockerfile-uv-base — installs causal_conv1d from git (no skip)
  • setup.py — lists causal_conv1d as a dependency

Proposed minimal change (remove the forced skip):

-    CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
+    python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \

If you need a CPU-only build option, gate it with a build ARG instead of forcing it in the CUDA image. Example approach:

  • Add near the top: ARG SKIP_CCONV_CUDA_BUILD="FALSE"
  • Replace the install line with a conditional in the RUN so the skip is only applied when SKIP_CCONV_CUDA_BUILD is explicitly set to "TRUE".

Please remove the env var or switch to the build-arg gating to keep GPU support available by default.
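
The gating idea above can be sketched in Python to make the intended logic explicit. This is an illustration only: the helper name `causal_conv1d_install_command` and its `skip_cuda_build` flag are hypothetical stand-ins for the proposed `SKIP_CCONV_CUDA_BUILD` build ARG. Only the `CAUSAL_CONV1D_SKIP_CUDA_BUILD` variable and the `causal_conv1d==1.5.2` pin come from the diff.

```python
import os
import sys


def causal_conv1d_install_command(skip_cuda_build=False, base_env=None):
    """Return (argv, env) for the pip install; nothing is executed here.

    Hypothetical helper: `skip_cuda_build` plays the role of the proposed
    SKIP_CCONV_CUDA_BUILD build ARG.
    """
    env = dict(os.environ if base_env is None else base_env)
    # Never inherit the skip flag by accident; it must be requested explicitly.
    env.pop("CAUSAL_CONV1D_SKIP_CUDA_BUILD", None)
    if skip_cuda_build:
        env["CAUSAL_CONV1D_SKIP_CUDA_BUILD"] = "TRUE"
    argv = [
        sys.executable, "-m", "pip", "install",
        "--no-cache-dir", "causal_conv1d==1.5.2",
    ]
    return argv, env
```

The returned pair could be handed to `subprocess.run(argv, env=env, check=True)`. By default the skip variable is absent, so the install does not opt out of the CUDA build.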

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
+    python3 -m pip install --no-cache-dir causal_conv1d==1.5.2 && \
     python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" && \
     python3 -m pip cache purge


4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

# START section of dependencies that don't install on Darwin/MacOS
-bitsandbytes==0.46.1
+bitsandbytes==0.47.0
# triton 3.4.0 is not compatible with CCE
triton>=3.0.0,<3.4.0
mamba-ssm==1.2.0.post1
@@ -14,7 +14,7 @@ packaging==23.2

huggingface_hub>=0.33.0
peft==0.17.0
-transformers==4.55.0
+transformers==4.55.1
tokenizers>=0.21.1
accelerate==1.10.0
datasets==4.0.0
2 changes: 0 additions & 2 deletions src/axolotl/loaders/patch_manager.py
@@ -285,12 +285,10 @@ def _apply_fsdp2_bnb_patches(self):
             and self.cfg.adapter == "qlora"
         ):
             from axolotl.monkeypatch.fsdp2_qlora import (
-                apply_bnb_torch_function_patch,
                 apply_init_sharded_param_patch,
                 apply_init_unsharded_param_patch,
             )

-            apply_bnb_torch_function_patch()
             apply_init_sharded_param_patch()
             apply_init_unsharded_param_patch()

61 changes: 0 additions & 61 deletions src/axolotl/monkeypatch/fsdp2_qlora.py
@@ -9,73 +9,12 @@
 import importlib
 import inspect

-import torch
-from torch.nn import Parameter
-
 from axolotl.monkeypatch.utils import detab_code
 from axolotl.utils.logging import get_logger

 LOG = get_logger(__name__)


-def patched_torch_function(cls, func, types, args=(), kwargs=None):
-    """
-    Patched version of Params4bit.__torch_function__ for preserving Params4bit
-    class identity and attributes.
-    """
-    if kwargs is None:
-        kwargs = {}
-
-    if func in [torch.chunk, torch.split]:
-        tensor = args[0]
-        result = Parameter.__torch_function__(func, types, args, kwargs)
-
-        if isinstance(result, tuple):
-            return tuple(
-                cls(
-                    data=chunk,
-                    requires_grad=tensor.requires_grad,
-                    quant_state=tensor.quant_state,
-                    blocksize=tensor.blocksize,
-                    compress_statistics=tensor.compress_statistics,
-                    quant_type=tensor.quant_type,
-                    quant_storage=tensor.quant_storage,
-                    module=tensor.module,
-                    bnb_quantized=tensor.bnb_quantized,
-                )
-                for chunk in result
-            )
-
-        return cls(
-            data=result,
-            requires_grad=tensor.requires_grad,
-            quant_state=tensor.quant_state,
-            blocksize=tensor.blocksize,
-            compress_statistics=tensor.compress_statistics,
-            quant_type=tensor.quant_type,
-            quant_storage=tensor.quant_storage,
-            module=tensor.module,
-            bnb_quantized=tensor.bnb_quantized,
-        )
-
-    return Parameter.__torch_function__(func, types, args, kwargs)
-
-
-# pylint: disable=protected-access
-def apply_bnb_torch_function_patch():
-    """
-    Patch Params4bit.__torch_function__ using Axolotl-style approach.
-
-    Returns:
-        True if patching succeeded, False otherwise.
-    """
-    from bitsandbytes.nn.modules import Params4bit
-
-    Params4bit.__torch_function__ = classmethod(patched_torch_function)
-
-    LOG.info("Successfully patched Params4bit.__torch_function__")
-
-
 # pylint: disable=protected-access
 def apply_init_sharded_param_patch():
     """Apply patch to FSDPParam._init_sharded_param to support Params4bit."""
11 changes: 7 additions & 4 deletions src/axolotl/monkeypatch/ring_attn/adapters/batch.py
@@ -20,12 +20,15 @@
from ring_flash_attn.adapters.hf_adapter import check_params
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal

-try:
+try:  # pylint: disable=duplicate-code
     from transformers.modeling_flash_attention_utils import _flash_supports_window
 except ImportError:
-    from transformers.modeling_flash_attention_utils import (
-        _flash_supports_window_size as _flash_supports_window,
-    )
+    try:
+        from transformers.modeling_flash_attention_utils import (
+            _flash_supports_window_size as _flash_supports_window,
+        )
+    except ImportError:
+        _flash_supports_window = True
Comment on lines +23 to +31

Contributor


⚠️ Potential issue

Avoid aliasing _flash_supports_window_size (likely a function) to a boolean flag.

If _flash_supports_window is absent and _flash_supports_window_size is a callable, the alias _flash_supports_window becomes a function object. A function object is always truthy, so any later boolean check on it passes and sliding windows may be enabled even when they are unsupported. Prefer a helper that queries the transformers module attribute(s) at runtime and returns a boolean.

Apply this refactor to robustly determine window support:

-try:  # pylint: disable=duplicate-code
-    from transformers.modeling_flash_attention_utils import _flash_supports_window
-except ImportError:
-    try:
-        from transformers.modeling_flash_attention_utils import (
-            _flash_supports_window_size as _flash_supports_window,
-        )
-    except ImportError:
-        _flash_supports_window = True
+try:  # pylint: disable=duplicate-code
+    import transformers.modeling_flash_attention_utils as _fau
+except ImportError:
+    _fau = None  # type: ignore
+
+def _flash_window_supported(sliding_window: int | None) -> bool:
+    # Default to True for ring-flash-attn path if utils are unavailable
+    if _fau is None:
+        return True
+    # Prefer _flash_supports_window if present
+    val = getattr(_fau, "_flash_supports_window", None)
+    try:
+        if callable(val):
+            return bool(val() if sliding_window is None else val(sliding_window))
+        if val is not None:
+            return bool(val)
+    except Exception:
+        pass
+    # Fallback to _flash_supports_window_size
+    val = getattr(_fau, "_flash_supports_window_size", None)
+    try:
+        if callable(val):
+            # If no explicit size is given, be conservative and assume True for ring-attn
+            return bool(val(sliding_window if sliding_window is not None else 0))
+        if val is not None:
+            return bool(val)
+    except Exception:
+        pass
+    # Final fallback for ring-flash-attn
+    return True
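
The truthiness concern raised in this comment can be reproduced without transformers. The sketch below is a minimal illustration: `supports_window` is a hypothetical stand-in for a capability symbol that happens to be a function. Whether the real transformers symbol is a function or a plain boolean in a given release is not established here.

```python
def supports_window(window_size=None):
    """Hypothetical capability check that reports 'unsupported'."""
    return False


# Aliasing the function and testing the alias inspects the object itself,
# not the value the function would return.
flag = supports_window
alias_is_truthy = bool(flag)  # function objects are always truthy
actual_support = flag()       # calling it gives the real answer

print(alias_is_truthy, actual_support)
```

If the symbol is a plain boolean instead, the alias behaves as intended, which is why the comment is conditional on the symbol being callable.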


from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

11 changes: 7 additions & 4 deletions src/axolotl/monkeypatch/ring_attn/patch.py
@@ -15,12 +15,15 @@
import torch.distributed as dist
from torch.distributed import DeviceMesh

-try:
+try:  # pylint: disable=duplicate-code
     from transformers.modeling_flash_attention_utils import _flash_supports_window
 except ImportError:
-    from transformers.modeling_flash_attention_utils import (
-        _flash_supports_window_size as _flash_supports_window,
-    )
+    try:
+        from transformers.modeling_flash_attention_utils import (
+            _flash_supports_window_size as _flash_supports_window,
+        )
+    except ImportError:
+        _flash_supports_window = True
Comment on lines +18 to +26

Contributor


⚠️ Potential issue

Same issue: aliasing _flash_supports_window_size to a boolean flag is fragile.

Mirror the robust runtime check pattern here to avoid always-True behavior on callable aliases and to accommodate transformers version drift.

-try:  # pylint: disable=duplicate-code
-    from transformers.modeling_flash_attention_utils import _flash_supports_window
-except ImportError:
-    try:
-        from transformers.modeling_flash_attention_utils import (
-            _flash_supports_window_size as _flash_supports_window,
-        )
-    except ImportError:
-        _flash_supports_window = True
+try:  # pylint: disable=duplicate-code
+    import transformers.modeling_flash_attention_utils as _fau
+except ImportError:
+    _fau = None  # type: ignore
+
+def _flash_window_supported(sliding_window: int | None) -> bool:
+    if _fau is None:
+        return True
+    val = getattr(_fau, "_flash_supports_window", None)
+    try:
+        if callable(val):
+            return bool(val() if sliding_window is None else val(sliding_window))
+        if val is not None:
+            return bool(val)
+    except Exception:
+        pass
+    val = getattr(_fau, "_flash_supports_window_size", None)
+    try:
+        if callable(val):
+            return bool(val(sliding_window if sliding_window is not None else 0))
+        if val is not None:
+            return bool(val)
+    except Exception:
+        pass
+    return True

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/ring_attn/patch.py around lines 18 to 26, the current
fallback aliases _flash_supports_window_size to a boolean which can cause
always-True behavior; instead, detect whether the imported symbol is callable
and, if it is not, wrap the boolean in a small runtime-check function so callers
expecting a function that accepts a window_size still get correct behavior.
Modify the import/fallback logic to: try the primary import, then try importing
_flash_supports_window_size and if that import yields a callable keep it as
_flash_supports_window, otherwise define _flash_supports_window = lambda
window_size: bool(_flash_supports_window_size) (or a simple def that returns the
boolean) and finally ensure a sane default function (e.g., returns True/False)
if both imports fail.
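
One way to read the prompt above is as a small normalizer that always hands callers a function, whatever shape the imported symbol had. The following is a sketch under assumptions: `as_window_check` is a hypothetical name, and a callable symbol is assumed to accept a single positional `window_size` argument.

```python
def as_window_check(symbol, default=True):
    """Return a function check(window_size=None) -> bool for any symbol.

    `symbol` may be a callable, a plain bool, or None (both imports failed).
    Assumption: a callable symbol takes one positional window_size argument.
    """
    if callable(symbol):
        return lambda window_size=None: bool(symbol(window_size))
    if symbol is None:
        return lambda window_size=None: bool(default)
    # Plain flag: ignore the argument and report the stored value.
    return lambda window_size=None: bool(symbol)
```

Callers then always write `check(window_size)` and never test the symbol itself, which sidesteps the always-truthy alias problem.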


from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.logging import get_logger
21 changes: 19 additions & 2 deletions src/axolotl/utils/schemas/validation.py
@@ -3,6 +3,7 @@
# pylint: disable=too-many-boolean-expressions

import json
+import sys
import tempfile
from pathlib import Path

@@ -1251,10 +1252,26 @@ def check_context_parallel_size(self):

     try:
         import transformers.modeling_flash_attention_utils
+        from transformers.utils import is_flash_attn_greater_or_equal

         # pylint: disable=protected-access
-        transformers.modeling_flash_attention_utils._flash_supports_window_size = (
-            transformers.modeling_flash_attention_utils._flash_supports_window
+        transformers.modeling_flash_attention_utils._flash_supports_window = (
+            True
         )
+        setattr(
+            sys.modules["transformers.modeling_flash_attention_utils"],
+            "_flash_supports_window",
+            True,
+        )
+        setattr(
+            sys.modules["transformers.modeling_flash_attention_utils"],
+            "_flash_supports_window_size",
+            True,
+        )
+        setattr(
+            sys.modules["transformers.modeling_flash_attention_utils"],
+            "is_flash_attn_greater_or_equal",
+            is_flash_attn_greater_or_equal,
+        )
         import ring_flash_attn  # noqa: F401 # pylint:disable=unused-import
     except ImportError as exception:
102 changes: 2 additions & 100 deletions tests/e2e/patched/test_fsdp2_qlora.py
@@ -1,126 +1,28 @@
-"""Integration tests for FSDP Params4bit patches."""
+"""Integration tests for FSDP2 Params4bit patches."""

-from unittest.mock import Mock, patch
-
-import bitsandbytes as bnb
 import pytest
-import torch
 from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam

Comment on lines 3 to 5

Contributor


💡 Verification agent

🧩 Analysis chain

Harden test against missing/changed Torch internals and ensure test isolation (restore patched methods).

  • Importing FSDPParam from a private Torch module at import time is brittle and can break test collection. Use pytest.importorskip inside the test to skip cleanly if the internal path moves/changes.
  • The test permanently mutates FSDPParam methods, which can leak state to other tests. Use pytest’s monkeypatch to capture/restore originals automatically.
  • Optional: Assert idempotency by re-applying the patches and ensuring the method objects don’t change a second time.

Apply the following changes:

-import pytest
-from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+import pytest
-    @pytest.mark.integration
-    def test_fsdp2_init_patches(self):
-        """Test that all patches can be applied together."""
+    @pytest.mark.integration
+    def test_fsdp2_init_patches(self, monkeypatch):
+        """Verify FSDP2 init patches are applied and idempotent."""
+        # Import Torch internal lazily and skip if unavailable/changed
+        fsdp_mod = pytest.importorskip("torch.distributed.fsdp._fully_shard._fsdp_param")
+        FSDPParam = fsdp_mod.FSDPParam
         from axolotl.monkeypatch.fsdp2_qlora import (
             apply_init_sharded_param_patch,
             apply_init_unsharded_param_patch,
         )
 
-        # pylint: disable=protected-access
-        original_init_sharded = FSDPParam._init_sharded_param
-        original_init_unsharded = FSDPParam.init_unsharded_param
+        # Ensure test isolation by restoring originals after the test
+        # monkeypatch.setattr records the current value and restores it at teardown.
+        monkeypatch.setattr(
+            FSDPParam, "_init_sharded_param", FSDPParam._init_sharded_param, raising=False
+        )
+        monkeypatch.setattr(
+            FSDPParam, "init_unsharded_param", FSDPParam.init_unsharded_param, raising=False
+        )
+
+        # pylint: disable=protected-access
+        original_init_sharded = FSDPParam._init_sharded_param
+        original_init_unsharded = FSDPParam.init_unsharded_param
 
         # Apply patches
         apply_init_sharded_param_patch()
         apply_init_unsharded_param_patch()
 
         assert (
             # pylint: disable=protected-access
             FSDPParam._init_sharded_param
             != original_init_sharded
         ), "_init_sharded_param was not patched"
         assert (
             FSDPParam.init_unsharded_param != original_init_unsharded
         ), "init_unsharded_param was not patched"
+
+        # Optional: patches should be idempotent (re-applying should not re-wrap)
+        patched_init_sharded = FSDPParam._init_sharded_param
+        patched_init_unsharded = FSDPParam.init_unsharded_param
+        apply_init_sharded_param_patch()
+        apply_init_unsharded_param_patch()
+        assert (
+            FSDPParam._init_sharded_param is patched_init_sharded
+        ), "_init_sharded_param patch should be idempotent"
+        assert (
+            FSDPParam.init_unsharded_param is patched_init_unsharded
+        ), "init_unsharded_param patch should be idempotent"

Follow-up:

  • If these patches may be pre-applied by other tests/modules, confirm the patch functions are no-ops when called repeatedly (idempotent). If not, consider adding guards in the patch implementations, or adapt the assertions to tolerate an already-patched state.

Also applies to: 11-33


🏁 Script executed:

#!/bin/bash
set -e

echo "===== Showing tests/e2e/patched/test_fsdp2_qlora.py ====="
sed -n '1,200p' tests/e2e/patched/test_fsdp2_qlora.py || echo "File not found: tests/e2e/patched/test_fsdp2_qlora.py"

echo
echo "===== Locating patch function definitions ====="
rg -n -C3 "def apply_init_sharded_param_patch" -n . || echo "No apply_init_sharded_param_patch definition found"
rg -n -C3 "def apply_init_unsharded_param_patch" -n . || echo "No apply_init_unsharded_param_patch definition found"

Length of output: 2639


🏁 Script executed:

#!/bin/bash
set -e

echo "===== Showing the first 200 lines of src/axolotl/monkeypatch/fsdp2_qlora.py ====="
sed -n '1,200p' src/axolotl/monkeypatch/fsdp2_qlora.py || echo "File not found"

Length of output: 5879


Harden test against Torch internal changes and ensure isolation

  • Avoid importing FSDPParam at module load from a private path; this can break collection if Torch internals move.
  • Lazily import via pytest.importorskip inside the test to skip cleanly on missing/renamed internals.
  • Use the monkeypatch fixture to snapshot and restore FSDPParam methods, preventing cross-test leakage.
  • (Optional) Re-apply the patches to assert idempotency.
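
The snapshot-and-restore behaviour these bullets rely on can be shown with the standard library alone. This is a stand-in for pytest's `monkeypatch` fixture, not its implementation: `restore_attrs` and `Demo` are hypothetical names, and the helper only handles plain function attributes.

```python
from contextlib import contextmanager


@contextmanager
def restore_attrs(obj, *names):
    """Snapshot the named attributes of `obj` and put them back on exit."""
    saved = {name: getattr(obj, name) for name in names}
    try:
        yield saved
    finally:
        # Runs even if the body raises, so patches cannot leak out.
        for name, value in saved.items():
            setattr(obj, name, value)


class Demo:
    """Hypothetical stand-in for a class such as FSDPParam."""

    def init_param(self):
        return "original"
```

Any patch applied to `Demo` inside `with restore_attrs(Demo, "init_param"):` is undone when the block exits, which is the isolation property the test needs.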

Proposed changes:

-import pytest
-from torch.distributed.fsdp._fully_shard._fsdp_param import FSDPParam
+import pytest
-    @pytest.mark.integration
-    def test_fsdp2_init_patches(self):
-        """Test that all patches can be applied together."""
+    @pytest.mark.integration
+    def test_fsdp2_init_patches(self, monkeypatch):
+        """Verify FSDPParam patches apply and are idempotent."""
+        # Lazily import Torch internals and skip if unavailable
+        fsdp_mod = pytest.importorskip("torch.distributed.fsdp._fully_shard._fsdp_param")
+        FSDPParam = fsdp_mod.FSDPParam
+
+        # Snapshot originals for automatic restore
+        monkeypatch.setattr(
+            FSDPParam, "_init_sharded_param", FSDPParam._init_sharded_param, raising=False
+        )
+        monkeypatch.setattr(
+            FSDPParam, "init_unsharded_param", FSDPParam.init_unsharded_param, raising=False
+        )
+
+        from axolotl.monkeypatch.fsdp2_qlora import (
+            apply_init_sharded_param_patch,
+            apply_init_unsharded_param_patch,
+        )
+
+        # Capture originals
+        original_init_sharded = FSDPParam._init_sharded_param
+        original_init_unsharded = FSDPParam.init_unsharded_param
+
+        # Apply patches
+        apply_init_sharded_param_patch()
+        apply_init_unsharded_param_patch()
+
+        # Confirm methods changed
+        assert FSDPParam._init_sharded_param is not original_init_sharded, \
+            "_init_sharded_param was not patched"
+        assert FSDPParam.init_unsharded_param is not original_init_unsharded, \
+            "init_unsharded_param was not patched"
+
+        # Optional: ensure idempotency
+        patched_sharded = FSDPParam._init_sharded_param
+        patched_unsharded = FSDPParam.init_unsharded_param
+        apply_init_sharded_param_patch()
+        apply_init_unsharded_param_patch()
+        assert FSDPParam._init_sharded_param is patched_sharded, \
+            "_init_sharded_param patch should be idempotent"
+        assert FSDPParam.init_unsharded_param is patched_unsharded, \
+            "init_unsharded_param patch should be idempotent"

Follow-up: if your patch functions may already have been applied elsewhere, ensure they’re no-ops on repeated calls or adapt the idempotency checks accordingly.
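
A guard of the kind suggested in the follow-up might look like the sketch below. It is generic and does not reflect how the axolotl patch functions are written: `Target`, `apply_patch`, and the `_is_patched` marker attribute are all hypothetical.

```python
class Target:
    """Hypothetical class whose method gets monkeypatched."""

    def init_param(self):
        return "original"


def apply_patch():
    """Wrap Target.init_param once; return False if it is already wrapped."""
    current = Target.init_param
    if getattr(current, "_is_patched", False):
        return False  # no-op on repeated calls

    def patched(self):
        return "patched:" + current(self)

    patched._is_patched = True  # hypothetical marker checked by the guard
    Target.init_param = patched
    return True
```

With such a guard, a second `apply_patch()` call leaves the method object unchanged, so an identity-based idempotency assertion like the one proposed above would hold.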


from axolotl.monkeypatch.fsdp2_qlora import (
apply_bnb_torch_function_patch,
patched_torch_function,
)


@pytest.fixture
def mock_params4bit():
"""Create a mock Params4bit instance with test attributes."""
mock_instance = Mock()
mock_instance.requires_grad = True
mock_instance.quant_state = "test_state"
mock_instance.blocksize = 128
mock_instance.compress_statistics = True
mock_instance.quant_type = "fp4"
mock_instance.quant_storage = "test_storage"
mock_instance.module = "test_module"
mock_instance.bnb_quantized = True
return mock_instance


class TestBnbTorchFunctionPatch:
"""Test the Params4bit.__torch_function__ patch."""

def test_apply_patch(self):
"""Test that the patch can be applied."""
with patch("bitsandbytes.nn.modules.Params4bit") as mock_cls:
apply_bnb_torch_function_patch()
assert hasattr(mock_cls, "__torch_function__")
assert isinstance(mock_cls.__torch_function__, classmethod)

# pylint: disable=redefined-outer-name
def test_torch_chunk_preserves_attributes(self, mock_params4bit):
"""Test that torch.chunk preserves Params4bit attributes."""
mock_cls = Mock()
chunks = (torch.tensor([1, 2]), torch.tensor([3, 4]))

with patch("torch.nn.Parameter.__torch_function__", return_value=chunks):
result = patched_torch_function(
mock_cls,
torch.chunk,
(type(mock_params4bit),),
args=(mock_params4bit, 2),
)

assert isinstance(result, tuple)
assert len(result) == 2

# Check that Params4bit constructor was called with preserved attributes
assert mock_cls.call_count == 2
for call in mock_cls.call_args_list:
kwargs = call[1]
assert kwargs["requires_grad"] == mock_params4bit.requires_grad
assert kwargs["quant_state"] == mock_params4bit.quant_state
assert kwargs["blocksize"] == mock_params4bit.blocksize

# pylint: disable=redefined-outer-name
def test_other_functions_fallback(self, mock_params4bit):
"""Test that non-chunk/split functions use Parameter fallback."""
mock_cls = Mock()
fallback_result = torch.tensor([5, 6, 7])

with patch(
"torch.nn.Parameter.__torch_function__", return_value=fallback_result
) as mock_fallback:
result = patched_torch_function(
mock_cls, torch.add, (type(mock_params4bit),), args=(mock_params4bit, 1)
)

# Should call Parameter.__torch_function__ and return its result
mock_fallback.assert_called_once()
assert result is fallback_result
mock_cls.assert_not_called()


class TestFSDPPatchIntegration:
"""Test FSDP patch integration."""

@pytest.mark.integration
-    def test_all_patches_together(self):
+    def test_fsdp2_init_patches(self):
"""Test that all patches can be applied together."""
from axolotl.monkeypatch.fsdp2_qlora import (
apply_init_sharded_param_patch,
apply_init_unsharded_param_patch,
)

# Store original methods before patching
original_torch_function = getattr(
bnb.nn.modules.Params4bit, "__torch_function__", None
)

# pylint: disable=protected-access
original_init_sharded = FSDPParam._init_sharded_param
original_init_unsharded = FSDPParam.init_unsharded_param

# Apply patches
apply_bnb_torch_function_patch()
apply_init_sharded_param_patch()
apply_init_unsharded_param_patch()

# Verify patches were applied
current_torch_function = getattr(
bnb.nn.modules.Params4bit, "__torch_function__", None
)
if original_torch_function is not None:
assert (
current_torch_function != original_torch_function
), "Params4bit.__torch_function__ was not patched"
else:
assert (
current_torch_function is not None
), "Params4bit.__torch_function__ was not added"

# Check that FSDP methods were patched
assert (
# pylint: disable=protected-access
FSDPParam._init_sharded_param