make multipack sampler patch explicit#3096
Conversation
📝 WalkthroughWalkthroughAdds a controlled, toggleable monkeypatch for PyTorch DataLoader to handle multipack/sample-packed batch indexes, integrates applying that patch in PatchManager when sample_packing is enabled, removes an import-time side effect from datasets, and updates tests to explicitly apply/remove the patch with guaranteed cleanup. Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
Tip 🔌 Remote MCP (Model Context Protocol) integration is now available!Pro plan users can now connect to remote MCP servers from the Integrations page. Connect with popular remote MCPs such as Notion and Linear to add more context to your reviews and chats. 📜 Recent review detailsConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro 💡 Knowledge Base configuration:
You can enable these sources in your CodeRabbit configuration. 📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
✨ Finishing Touches
🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR/Issue comments)Type Other keywords and placeholders
Status, Documentation and Community
|
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/axolotl/monkeypatch/data/batch_dataset_fetcher.py (1)
20-44: Guard against non-batched indices to prevent crashes when auto_collation=FalseAccessing
possibly_batched_index[0]unconditionally will raise whenauto_collationis False (e.g., DataLoader withbatch_size=None). Since this patch is process-wide, it can break unrelated DataLoaders in the same process. Detect multipack only whenauto_collationis on and the first element is a sequence (list/tuple).Apply this diff to make the fetcher robust and widen support beyond
list:- def fetch(self, possibly_batched_index): - if isinstance(possibly_batched_index[0], list): - # Handle nested structure from MultipackBatchSampler - data = [None for i in possibly_batched_index] - for i, possibly_batched_index_ in enumerate(possibly_batched_index): - if self.auto_collation: - if ( - hasattr(self.dataset, "__getitems__") - and self.dataset.__getitems__ - ): - data[i] = self.dataset.__getitems__(possibly_batched_index_) - else: - data[i] = [self.dataset[idx] for idx in possibly_batched_index_] - else: - data[i] = self.dataset[possibly_batched_index_] - else: - # Standard batch handling - if self.auto_collation: - if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__: - data = self.dataset.__getitems__(possibly_batched_index) - else: - data = [self.dataset[idx] for idx in possibly_batched_index] - else: - data = self.dataset[possibly_batched_index] + def fetch(self, possibly_batched_index): + # Multipack only if auto-collation is on and batch is nested (e.g., [[...], [...], ...]) + is_multipack = ( + self.auto_collation + and isinstance(possibly_batched_index, (list, tuple)) + and len(possibly_batched_index) > 0 + and isinstance(possibly_batched_index[0], (list, tuple)) + ) + + if is_multipack: + # Handle nested structure from MultipackBatchSampler + data = [None] * len(possibly_batched_index) + for i, pack_indices in enumerate(possibly_batched_index): + if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__: + data[i] = self.dataset.__getitems__(pack_indices) + else: + data[i] = [self.dataset[idx] for idx in pack_indices] + else: + # Standard batch handling or single index + if self.auto_collation: + if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__: + data = self.dataset.__getitems__(possibly_batched_index) + else: + data = [self.dataset[idx] for idx in possibly_batched_index] + else: + data = self.dataset[possibly_batched_index] return self.collate_fn(data)Additional nits:
- Replace
for i in possibly_batched_indexwith[None] * len(possibly_batched_index)(applied above).- Prefer
_over the unused loop variable name when appropriate.
🧹 Nitpick comments (5)
src/axolotl/monkeypatch/data/batch_dataset_fetcher.py (2)
59-79: Optional: make patch/unpatch thread-safe and add trace loggingIf multiple threads (or repeated initializations) can touch this patching surface, a simple
RLockplus debug logs will make behavior deterministic and easier to reason about.Language-specific snippet to add at module top (outside the shown ranges):
import threading _LOCK = threading.RLock()Then wrap bodies:
- if _IS_PATCHED: - return + with _LOCK: + if _IS_PATCHED: + return # ... existing body ... - _IS_PATCHED = True + _IS_PATCHED = TrueApply the same pattern in
remove_multipack_dataloader_patch().
81-99: Good: complete restoration of both fetcher references and worker loopRestoring both
_utils.fetchanddataloader._utils.fetchplus the original worker loop prevents sticky global side effects between tests/runs. Consider nulling the stored originals after restoration to discourage accidental reuse and aid GC.if _ORIGINAL_WORKER_LOOP: torch.utils.data._utils.worker._worker_loop = _ORIGINAL_WORKER_LOOP - _IS_PATCHED = False + _IS_PATCHED = False + # Optional: clear originals + # _ORIGINAL_MAP_DATASET_FETCHER = None + # _ORIGINAL_WORKER_LOOP = Nonetests/test_packed_batch_sampler.py (2)
112-113: Verify assertions target shape semantics rather than total elements
numel() <= batch_size * max_seq_lengthcouples correctness to total elements and can be brittle if the collator changes batch dimensionality in the future (e.g., extra packs or ragged stacking). Prefer checking per-dimension constraints.- assert batch["input_ids"].numel() <= batch_size * max_seq_length - assert batch["input_ids"].shape[1] == max_seq_length + # Expect sequence length to be padded/clamped to max_seq_length + assert batch["input_ids"].shape[1] == max_seq_length + # Expect effective batch dimension not to exceed requested batch_size + assert batch["input_ids"].shape[0] <= batch_sizeIf
V2BatchSamplerDataCollatorForSeq2Seqintentionally flattens packs into the batch dimension, the second assertion remains valid; if not, adjust to the appropriate dimension.
110-121: Good: guaranteed cleanup via finally; consider a reusable fixture to DRY this across testsThe
try/finallyensures the global monkeypatch is removed even on failure. If more tests need this patch, a function-scoped fixture reduces repetition and centralizes cleanup.Add a fixture (outside this file’s changed hunk):
import pytest from axolotl.monkeypatch.data.batch_dataset_fetcher import ( apply_multipack_dataloader_patch, remove_multipack_dataloader_patch, ) @pytest.fixture def multipack_dataloader_patch(): apply_multipack_dataloader_patch() try: yield finally: remove_multipack_dataloader_patch()Then use it:
def test_packing(..., multipack_dataloader_patch): ...src/axolotl/loaders/patch_manager.py (1)
281-290: Harden import and add idempotent logging; consider providing a symmetric unpatchImport errors or internal layout changes shouldn’t break model loading. Wrap the import/call in a try/except with a clear log. Also consider exposing a matching teardown path (e.g., a
reset_patches()on PatchManager) that callsremove_multipack_dataloader_patch()to avoid sticky state in long-lived processes.def _apply_multipack_dataloader_patch(self): """Apply multipack dataloader patch if sample packing is enabled.""" if self.cfg.sample_packing: - from axolotl.monkeypatch.data.batch_dataset_fetcher import ( - apply_multipack_dataloader_patch, - ) - - LOG.info("Applying multipack dataloader patch for sample packing...") - apply_multipack_dataloader_patch() + try: + from axolotl.monkeypatch.data.batch_dataset_fetcher import ( + apply_multipack_dataloader_patch, + ) + LOG.info("Applying multipack dataloader patch for sample packing...") + apply_multipack_dataloader_patch() + except Exception as e: + LOG.warning("Failed to apply multipack dataloader patch: %s", e)If you’d like, I can add a
teardown()method toPatchManagerthat unpatches all applied monkeypatches, and wire it where your lifecycle ends (trainer shutdown).
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (4)
src/axolotl/common/datasets.py(0 hunks)src/axolotl/loaders/patch_manager.py(2 hunks)src/axolotl/monkeypatch/data/batch_dataset_fetcher.py(3 hunks)tests/test_packed_batch_sampler.py(2 hunks)
💤 Files with no reviewable changes (1)
- src/axolotl/common/datasets.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/test_packed_batch_sampler.py (1)
src/axolotl/monkeypatch/data/batch_dataset_fetcher.py (2)
apply_multipack_dataloader_patch(59-78)remove_multipack_dataloader_patch(81-98)
src/axolotl/loaders/patch_manager.py (2)
src/axolotl/integrations/base.py (2)
cfg(352-353)cfg(356-357)src/axolotl/monkeypatch/data/batch_dataset_fetcher.py (1)
apply_multipack_dataloader_patch(59-78)
🔇 Additional comments (4)
src/axolotl/monkeypatch/data/batch_dataset_fetcher.py (1)
53-56: Good: ensure worker processes apply the fetcher patch before running the original loopWrapping
_worker_loopto callpatch_fetchers()in workers aligns with spawn/fork semantics and keeps worker environments consistent with the parent process.tests/test_packed_batch_sampler.py (1)
51-57: Good: apply the patch explicitly and locally in the testImporting and applying the dataloader patch inside the test eliminates import-time side effects and clarifies intent.
src/axolotl/loaders/patch_manager.py (2)
64-65: Good: centralize dataloader patching in the pre-model-load phaseCalling
_apply_multipack_dataloader_patch()here aligns the patch with other pre-load mutations and gates it behindcfg.sample_packing.
50-71: Patch application occurs before any DataLoader constructionI’ve confirmed that:
apply_pre_model_load_patches()is invoked as the very first step inModelLoader.load()(src/axolotl/loaders/model.py line 168), ensuring all internal monkey-patches (including the multipack and DataLoader patches) are applied upfront.- All
DataLoader(...)calls appear downstream in the training flow—e.g., in
• src/axolotl/utils/trainer.py:474
• src/axolotl/core/trainers/base.py:281, 319
• src/axolotl/core/trainers/grpo/trainer.py:235Since the patch manager’s pre-load step always runs before any dataset/DataLoader instantiation, no changes are required here.
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
Description
Implicit patch was bothering me. Plus, centralized the application to our patch manager.
As an aside, we can probably fix some test isolation issues by removing applied patches after tests execute, like I'm doing here.
Motivation and Context
How has this been tested?
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit