
make multipack sampler patch explicit#3096

Merged
djsaunde merged 2 commits into main from explicit-patch on Aug 22, 2025

Conversation

@djsaunde djsaunde commented Aug 22, 2025

Collaborator

Description

The implicit patch was bothering me. This PR also centralizes its application in our patch manager.

As an aside, we can probably fix some test isolation issues by removing applied patches after tests execute, like I'm doing here.

Motivation and Context

How has this been tested?

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features
    • Added dynamic DataLoader patching to support sample-packed batches; automatically applied when sample packing is enabled and can be toggled on/off.
  • Improvements
    • Reduced import-time side effects for more predictable startup behavior.
    • Enhanced reliability of sample packing by ensuring patches are applied consistently across workers.
  • Tests
    • Updated tests to explicitly enable/disable the DataLoader patch with guaranteed cleanup for more stable test runs.

@djsaunde djsaunde requested a review from winglian August 22, 2025 15:14
@djsaunde djsaunde self-assigned this Aug 22, 2025
@djsaunde djsaunde marked this pull request as ready for review August 22, 2025 15:14

coderabbitai Bot commented Aug 22, 2025

Contributor

Walkthrough

Adds a controlled, toggleable monkeypatch for PyTorch DataLoader to handle multipack/sample-packed batch indexes, integrates applying that patch in PatchManager when sample_packing is enabled, removes an import-time side effect from datasets, and updates tests to explicitly apply/remove the patch with guaranteed cleanup.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Monkeypatch for multipack dataloader: src/axolotl/monkeypatch/data/batch_dataset_fetcher.py | Implements custom _MapDatasetFetcher.fetch for nested/batched indexes, adds module-level patch state, and exposes patch orchestration APIs: patch_fetchers(), apply_multipack_dataloader_patch(), and remove_multipack_dataloader_patch(). Adds a patched worker loop to ensure worker processes apply the patch. |
| PatchManager integration: src/axolotl/loaders/patch_manager.py | Calls apply_multipack_dataloader_patch() via a local import when cfg.sample_packing is true during the multipack patch flow; logs the action. |
| Import side-effect cleanup: src/axolotl/common/datasets.py | Removes the top-level import of axolotl.monkeypatch.data.batch_dataset_fetcher to avoid import-time side effects. |
| Tests update: tests/test_packed_batch_sampler.py | Explicitly imports and calls apply_multipack_dataloader_patch() and remove_multipack_dataloader_patch(); wraps test assertions in try/finally to ensure cleanup. |
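
The apply/remove pairing summarized above can be illustrated without PyTorch. The following is a minimal sketch, not the PR's implementation: `Fetcher`, `_packed_fetch`, `apply_patch`, and `remove_patch` are hypothetical stand-ins, and only the idea of module-level patch state with an explicit toggle comes from the walkthrough.

```python
class Fetcher:
    """Hypothetical stand-in for the class being monkeypatched."""

    def fetch(self, index):
        return ("original", index)


_ORIGINAL_FETCH = None  # saved so the patch can be undone later
_IS_PATCHED = False     # module-level patch state


def _packed_fetch(self, index):
    """Hypothetical replacement method installed by apply_patch()."""
    return ("patched", index)


def apply_patch():
    """Install the replacement once; repeated calls are no-ops."""
    global _ORIGINAL_FETCH, _IS_PATCHED
    if _IS_PATCHED:
        return
    _ORIGINAL_FETCH = Fetcher.fetch
    Fetcher.fetch = _packed_fetch
    _IS_PATCHED = True


def remove_patch():
    """Restore the saved original; safe to call when nothing is patched."""
    global _IS_PATCHED
    if not _IS_PATCHED:
        return
    Fetcher.fetch = _ORIGINAL_FETCH
    _IS_PATCHED = False
```

Because nothing runs at import time in this sketch, importing the module leaves `Fetcher` untouched until `apply_patch()` is called, which is the property the import side-effect cleanup is after.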

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • winglian
  • SalmanMohammadi


📜 Recent review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 311b4c2 and a7cb761.

📒 Files selected for processing (1)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/loaders/patch_manager.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/axolotl/monkeypatch/data/batch_dataset_fetcher.py (1)

20-44: Guard against non-batched indices to prevent crashes when auto_collation=False

Accessing possibly_batched_index[0] unconditionally will raise when auto_collation is False (e.g., DataLoader with batch_size=None). Since this patch is process-wide, it can break unrelated DataLoaders in the same process. Detect multipack only when auto_collation is on and the first element is a sequence (list/tuple).

Apply this diff to make the fetcher robust and widen support beyond list:

-    def fetch(self, possibly_batched_index):
-        if isinstance(possibly_batched_index[0], list):
-            # Handle nested structure from MultipackBatchSampler
-            data = [None for i in possibly_batched_index]
-            for i, possibly_batched_index_ in enumerate(possibly_batched_index):
-                if self.auto_collation:
-                    if (
-                        hasattr(self.dataset, "__getitems__")
-                        and self.dataset.__getitems__
-                    ):
-                        data[i] = self.dataset.__getitems__(possibly_batched_index_)
-                    else:
-                        data[i] = [self.dataset[idx] for idx in possibly_batched_index_]
-                else:
-                    data[i] = self.dataset[possibly_batched_index_]
-        else:
-            # Standard batch handling
-            if self.auto_collation:
-                if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
-                    data = self.dataset.__getitems__(possibly_batched_index)
-                else:
-                    data = [self.dataset[idx] for idx in possibly_batched_index]
-            else:
-                data = self.dataset[possibly_batched_index]
+    def fetch(self, possibly_batched_index):
+        # Multipack only if auto-collation is on and batch is nested (e.g., [[...], [...], ...])
+        is_multipack = (
+            self.auto_collation
+            and isinstance(possibly_batched_index, (list, tuple))
+            and len(possibly_batched_index) > 0
+            and isinstance(possibly_batched_index[0], (list, tuple))
+        )
+
+        if is_multipack:
+            # Handle nested structure from MultipackBatchSampler
+            data = [None] * len(possibly_batched_index)
+            for i, pack_indices in enumerate(possibly_batched_index):
+                if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
+                    data[i] = self.dataset.__getitems__(pack_indices)
+                else:
+                    data[i] = [self.dataset[idx] for idx in pack_indices]
+        else:
+            # Standard batch handling or single index
+            if self.auto_collation:
+                if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
+                    data = self.dataset.__getitems__(possibly_batched_index)
+                else:
+                    data = [self.dataset[idx] for idx in possibly_batched_index]
+            else:
+                data = self.dataset[possibly_batched_index]
         return self.collate_fn(data)

Additional nits:

  • Replace [None for i in possibly_batched_index] with [None] * len(possibly_batched_index) (applied above).
  • Prefer _ over the unused loop variable name when appropriate.
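
The guard suggested in this comment can be exercised on its own, away from the DataLoader internals. This is a minimal sketch under stated assumptions: `is_multipack_index` and `fetch_items` are hypothetical helper names, a plain Python list stands in for the dataset, and collation is omitted.

```python
def is_multipack_index(index, auto_collation):
    """True only for a non-empty list/tuple whose first element is a list/tuple."""
    return bool(
        auto_collation
        and isinstance(index, (list, tuple))
        and len(index) > 0
        and isinstance(index[0], (list, tuple))
    )


def fetch_items(dataset, index, auto_collation=True):
    """Hypothetical fetch helper: nested packs, a flat batch, or a single index."""
    if is_multipack_index(index, auto_collation):
        # Nested structure: one inner list of samples per pack.
        return [[dataset[i] for i in pack] for pack in index]
    if auto_collation:
        # Standard batch: a flat list of indices.
        return [dataset[i] for i in index]
    # No auto-collation: the index is passed to the dataset unchanged.
    return dataset[index]
```

With this ordering, an integer index is never subscripted, so the non-batched path cannot fail on `index[0]`.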
🧹 Nitpick comments (5)
src/axolotl/monkeypatch/data/batch_dataset_fetcher.py (2)

59-79: Optional: make patch/unpatch thread-safe and add trace logging

If multiple threads (or repeated initializations) can touch this patching surface, a simple RLock plus debug logs will make behavior deterministic and easier to reason about.

Language-specific snippet to add at module top (outside the shown ranges):

import threading
_LOCK = threading.RLock()

Then wrap bodies:

-    if _IS_PATCHED:
-        return
+    with _LOCK:
+        if _IS_PATCHED:
+            return
         # ... existing body ...
-    _IS_PATCHED = True
+        _IS_PATCHED = True

Apply the same pattern in remove_multipack_dataloader_patch().
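
As a self-contained illustration of the locking pattern, the sketch below guards a toggle with an RLock and counts how often the patch body runs. The names `apply_patch`, `remove_patch`, `is_patched`, `apply_from_threads`, and the `_APPLY_COUNT` counter are hypothetical and exist only to make the effect observable; the real patching work is replaced by the counter increment.

```python
import threading

_LOCK = threading.RLock()
_IS_PATCHED = False
_APPLY_COUNT = 0  # hypothetical counter standing in for the real patching work


def apply_patch():
    """Run the patch body at most once, even under concurrent calls."""
    global _IS_PATCHED, _APPLY_COUNT
    with _LOCK:
        if _IS_PATCHED:
            return
        _APPLY_COUNT += 1
        _IS_PATCHED = True


def remove_patch():
    """Clear the patched flag under the same lock."""
    global _IS_PATCHED
    with _LOCK:
        if not _IS_PATCHED:
            return
        _IS_PATCHED = False


def is_patched():
    with _LOCK:
        return _IS_PATCHED


def apply_from_threads(n=8):
    """Call apply_patch() from n threads and report how often the body ran."""
    threads = [threading.Thread(target=apply_patch) for _ in range(n)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return _APPLY_COUNT
```

The check and the state change sit inside the same `with _LOCK` block, so no second thread can observe the flag as unset after the first thread has started patching.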


81-99: Good: complete restoration of both fetcher references and worker loop

Restoring both _utils.fetch and dataloader._utils.fetch plus the original worker loop prevents sticky global side effects between tests/runs. Consider nulling the stored originals after restoration to discourage accidental reuse and aid GC.

     if _ORIGINAL_WORKER_LOOP:
         torch.utils.data._utils.worker._worker_loop = _ORIGINAL_WORKER_LOOP
 
-    _IS_PATCHED = False
+    _IS_PATCHED = False
+    # Optional: clear originals
+    # _ORIGINAL_MAP_DATASET_FETCHER = None
+    # _ORIGINAL_WORKER_LOOP = None
tests/test_packed_batch_sampler.py (2)

112-113: Verify assertions target shape semantics rather than total elements

numel() <= batch_size * max_seq_length couples correctness to total elements and can be brittle if the collator changes batch dimensionality in the future (e.g., extra packs or ragged stacking). Prefer checking per-dimension constraints.

-                assert batch["input_ids"].numel() <= batch_size * max_seq_length
-                assert batch["input_ids"].shape[1] == max_seq_length
+                # Expect sequence length to be padded/clamped to max_seq_length
+                assert batch["input_ids"].shape[1] == max_seq_length
+                # Expect effective batch dimension not to exceed requested batch_size
+                assert batch["input_ids"].shape[0] <= batch_size

If V2BatchSamplerDataCollatorForSeq2Seq intentionally flattens packs into the batch dimension, the second assertion remains valid; if not, adjust to the appropriate dimension.


110-121: Good: guaranteed cleanup via finally; consider a reusable fixture to DRY this across tests

The try/finally ensures the global monkeypatch is removed even on failure. If more tests need this patch, a function-scoped fixture reduces repetition and centralizes cleanup.

Add a fixture (outside this file’s changed hunk):

import pytest
from axolotl.monkeypatch.data.batch_dataset_fetcher import (
    apply_multipack_dataloader_patch,
    remove_multipack_dataloader_patch,
)

@pytest.fixture
def multipack_dataloader_patch():
    apply_multipack_dataloader_patch()
    try:
        yield
    finally:
        remove_multipack_dataloader_patch()

Then use it:

def test_packing(..., multipack_dataloader_patch):
    ...
src/axolotl/loaders/patch_manager.py (1)

281-290: Harden import and add idempotent logging; consider providing a symmetric unpatch

Import errors or internal layout changes shouldn’t break model loading. Wrap the import/call in a try/except with a clear log. Also consider exposing a matching teardown path (e.g., a reset_patches() on PatchManager) that calls remove_multipack_dataloader_patch() to avoid sticky state in long-lived processes.

     def _apply_multipack_dataloader_patch(self):
         """Apply multipack dataloader patch if sample packing is enabled."""
         if self.cfg.sample_packing:
-            from axolotl.monkeypatch.data.batch_dataset_fetcher import (
-                apply_multipack_dataloader_patch,
-            )
-
-            LOG.info("Applying multipack dataloader patch for sample packing...")
-            apply_multipack_dataloader_patch()
+            try:
+                from axolotl.monkeypatch.data.batch_dataset_fetcher import (
+                    apply_multipack_dataloader_patch,
+                )
+                LOG.info("Applying multipack dataloader patch for sample packing...")
+                apply_multipack_dataloader_patch()
+            except Exception as e:
+                LOG.warning("Failed to apply multipack dataloader patch: %s", e)

If you’d like, I can add a teardown() method to PatchManager that unpatches all applied monkeypatches, and wire it where your lifecycle ends (trainer shutdown).
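
One way such a teardown path could look is a registry that stores an undo callback for each applied patch. This is only a sketch of the idea: `PatchRegistry` and its methods are hypothetical and are not part of Axolotl's PatchManager.

```python
class PatchRegistry:
    """Hypothetical helper that records an undo callback per applied patch."""

    def __init__(self):
        self._undo = []

    def apply(self, apply_fn, remove_fn):
        """Apply a patch and remember how to remove it."""
        apply_fn()
        self._undo.append(remove_fn)

    def teardown(self):
        """Undo patches in reverse order; calling it again is a no-op."""
        while self._undo:
            self._undo.pop()()
```

In this arrangement a call such as `registry.apply(apply_multipack_dataloader_patch, remove_multipack_dataloader_patch)` would pair the two functions named in this PR, assuming both are imported first.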

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ab4d604 and 311b4c2.

📒 Files selected for processing (4)
  • src/axolotl/common/datasets.py (0 hunks)
  • src/axolotl/loaders/patch_manager.py (2 hunks)
  • src/axolotl/monkeypatch/data/batch_dataset_fetcher.py (3 hunks)
  • tests/test_packed_batch_sampler.py (2 hunks)
💤 Files with no reviewable changes (1)
  • src/axolotl/common/datasets.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/test_packed_batch_sampler.py (1)
src/axolotl/monkeypatch/data/batch_dataset_fetcher.py (2)
  • apply_multipack_dataloader_patch (59-78)
  • remove_multipack_dataloader_patch (81-98)
src/axolotl/loaders/patch_manager.py (2)
src/axolotl/integrations/base.py (2)
  • cfg (352-353)
  • cfg (356-357)
src/axolotl/monkeypatch/data/batch_dataset_fetcher.py (1)
  • apply_multipack_dataloader_patch (59-78)
🔇 Additional comments (4)
src/axolotl/monkeypatch/data/batch_dataset_fetcher.py (1)

53-56: Good: ensure worker processes apply the fetcher patch before running the original loop

Wrapping _worker_loop to call patch_fetchers() in workers aligns with spawn/fork semantics and keeps worker environments consistent with the parent process.
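
The wrapping described here follows a general pattern: run a setup step, then delegate to the original function. The sketch below shows that pattern in isolation. `patch_fetchers` and `original_worker_loop` are local stand-ins defined for illustration (the first borrows the name used in this PR but not its body), and `wrap_worker_loop` and `CALLS` are hypothetical.

```python
import functools

CALLS = []  # records call order, for illustration only


def patch_fetchers():
    """Stand-in for the setup step that must run inside each worker."""
    CALLS.append("patch_fetchers")


def original_worker_loop(worker_id):
    """Stand-in for the function being wrapped."""
    CALLS.append(f"worker_loop:{worker_id}")
    return worker_id


def wrap_worker_loop(loop_fn):
    """Return a wrapper that applies the patch, then runs the original loop."""

    @functools.wraps(loop_fn)
    def patched_worker_loop(*args, **kwargs):
        patch_fetchers()
        return loop_fn(*args, **kwargs)

    return patched_worker_loop
```

Keeping a reference to the unwrapped function is what makes later restoration possible, which matches the restore step praised in the comment on lines 81-99.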

tests/test_packed_batch_sampler.py (1)

51-57: Good: apply the patch explicitly and locally in the test

Importing and applying the dataloader patch inside the test eliminates import-time side effects and clarifies intent.

src/axolotl/loaders/patch_manager.py (2)

64-65: Good: centralize dataloader patching in the pre-model-load phase

Calling _apply_multipack_dataloader_patch() here aligns the patch with other pre-load mutations and gates it behind cfg.sample_packing.


50-71: Patch application occurs before any DataLoader construction

I’ve confirmed that:

  • apply_pre_model_load_patches() is invoked as the very first step in ModelLoader.load() (src/axolotl/loaders/model.py line 168), ensuring all internal monkey-patches (including the multipack and DataLoader patches) are applied upfront.
  • All DataLoader(...) calls appear downstream in the training flow—e.g., in
    • src/axolotl/utils/trainer.py:474
    • src/axolotl/core/trainers/base.py:281, 319
    • src/axolotl/core/trainers/grpo/trainer.py:235

Since the patch manager’s pre-load step always runs before any dataset/DataLoader instantiation, no changes are required here.

Comment thread src/axolotl/monkeypatch/data/batch_dataset_fetcher.py
@codecov

codecov Bot commented Aug 22, 2025

Codecov Report

❌ Patch coverage is 95.83333% with 1 line in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| .../axolotl/monkeypatch/data/batch_dataset_fetcher.py | 95.00% | 1 Missing ⚠️ |

@djsaunde djsaunde merged commit eea7a00 into main Aug 22, 2025
13 of 15 checks passed
@djsaunde djsaunde deleted the explicit-patch branch August 22, 2025 18:29

2 participants