Skip to content

Streaming SFT support#3101

Merged
djsaunde merged 24 commits into
mainfrom
streaming-v3
Sep 2, 2025
Merged

Streaming SFT support#3101
djsaunde merged 24 commits into
mainfrom
streaming-v3

Conversation

@djsaunde

@djsaunde djsaunde commented Aug 26, 2025

Copy link
Copy Markdown
Collaborator

Description

Streaming datasets should not be limited to pretraining only. This PR changes that.

Also removed --iterable preprocess CLI arg since it was confusing / maybe not advisable to do.

This is a follow-up to #3087 which is smaller in scope since things were broken / getting complicated in that PR.

TODO:

Follow-ups:

Motivation and Context

Some folks have been asking for better dataset streaming support. This is part of a larger effort to refactor / improve the logic around how we handle datasets / dataloading.

How has this been tested?

  • Unit tests
  • Smoke tests
  • Manually (comparing average loss, grad norm, ...)
    • stream vs. no stream, packing
    • stream vs. no stream, no packing

Summary by CodeRabbit

  • New Features

    • Streaming datasets for training: sample packing, attention-isolation controls, automatic streaming when a pretraining dataset is provided, new streaming config flags (streaming, streaming_multipack_buffer_size).
  • Documentation

    • Added “Streaming Datasets” guide and examples/streaming README with pretrain and SFT example configs.
  • Bug Fixes

    • Preserve streaming behavior in fallback loading, safer long-sequence handling, and reduced repeated empty-prompt warnings.
  • Chores

    • Deprecated preprocess --iterable; use streaming: true or --streaming.
  • Tests

    • New unit and end-to-end tests covering streaming behavior.

@djsaunde djsaunde self-assigned this Aug 26, 2025
@coderabbitai

coderabbitai Bot commented Aug 26, 2025

Copy link
Copy Markdown
Contributor

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Walkthrough

Adds streaming-dataset support and docs; deprecates the preprocess-time iterable flag and routes streaming via the train path. Renames pretraining encoding/wrapping APIs to streaming equivalents, removes ConstantLengthDataset, refactors dataset loading/tokenization for streaming-first behavior, updates configs/validation, examples, and tests.

Changes

Cohort / File(s) Summary of changes
Docs & Navigation
_quarto.yml, docs/streaming.qmd
Adds a Streaming Datasets doc and inserts it into Core Concepts navigation.
Examples: Streaming
examples/streaming/README.md, examples/streaming/pretrain.yaml, examples/streaming/sft.yaml
New README and example YAMLs demonstrating streaming pretraining and SFT configs and usage.
CLI: Deprecate preprocess iterable
src/axolotl/cli/args.py, src/axolotl/cli/preprocess.py
Marks --iterable deprecated (default False), updates help text, and makes axolotl preprocess exit early with guidance to use axolotl train --streaming/YAML streaming.
Dataset wrapping & tokenization
src/axolotl/datasets.py, src/axolotl/prompt_tokenizers.py
Removes ConstantLengthDataset, simplifies wrap_dataset_for_tokenized_prompt (maps IterableDatasets), adds type hints, and changes empty-prompt logging to a one-time warning.
Data API rename → streaming
src/axolotl/utils/data/__init__.py, src/axolotl/utils/data/streaming.py
Renames exported functions: encode_pretrainingencode_streaming, wrap_pretraining_datasetwrap_streaming_dataset, adjusts wrapper signatures/logic to derive sizes from cfg and update shuffle/map buffer usage.
SFT / loading refactor for streaming
src/axolotl/utils/data/sft.py, src/axolotl/common/datasets.py
Streaming-first rearchitecture: new _prepare_streaming_dataset/_load_streaming_dataset, remove preprocess_iterable threading, add IterableDatasetDict handling, and route prepare_datasets based on cfg.streaming or pretraining_dataset.
Data helpers & robustness
src/axolotl/utils/data/shared.py, src/axolotl/utils/data/utils.py
Preserve caller-provided streaming flag in local-path fallback; harden long-sequence handling to tolerate missing/falsey column_names and IterableDatasets.
Config & Validation
src/axolotl/utils/schemas/config.py, src/axolotl/utils/schemas/validation.py
Add streaming and streaming_multipack_buffer_size; deprecate pretrain_multipack_buffer_size; add validators to migrate buffer setting and enforce streaming-related constraints (val_set_size, max_steps).
Collators export surface
src/axolotl/utils/collators/__init__.py
Adds explicit __all__ and minor import/docstring formatting adjustments.
Tests: Streaming & packing updates
tests/e2e/test_streaming.py, tests/test_streaming.py, tests/test_packed_pretraining.py, tests/test_packed_dataset.py, tests/test_data.py, tests/e2e/integrations/test_kd.py
Adds unit and e2e streaming tests, updates tests to new streaming APIs, removes ConstantLengthDataset-based test, replaces pretraining wrappers with streaming wrappers, and tweaks KD chat_template to qwen3.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Possibly related PRs

Suggested labels

scheduled_release

Suggested reviewers

  • winglian
  • NanoCode012
  • SalmanMohammadi
✨ Finishing Touches
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch streaming-v3

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR/Issue comments)

Type @coderabbitai help to get the list of available commands.

Other keywords and placeholders

  • Add @coderabbitai ignore or @coderabbit ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

Status, Documentation and Community

  • Visit our Status Page to check the current availability of CodeRabbit.
  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@djsaunde djsaunde changed the title Streaming v3 Streaming SFT support Aug 26, 2025
@djsaunde djsaunde requested review from NanoCode012, salmanmohammadi and winglian and removed request for NanoCode012 and winglian August 26, 2025 16:16
@djsaunde djsaunde marked this pull request as ready for review August 26, 2025 16:16
@github-actions

github-actions Bot commented Aug 26, 2025

Copy link
Copy Markdown
Contributor

📖 Documentation Preview: https://68b70d793f71b914b66994cb--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 4d1a47b

@djsaunde djsaunde requested a review from NanoCode012 August 26, 2025 16:23
@codecov

codecov Bot commented Aug 26, 2025

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 91.95402% with 7 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/axolotl/utils/schemas/validation.py 92.30% 3 Missing ⚠️
src/axolotl/utils/data/utils.py 50.00% 2 Missing ⚠️
src/axolotl/cli/preprocess.py 50.00% 1 Missing ⚠️
src/axolotl/prompt_tokenizers.py 0.00% 1 Missing ⚠️

📢 Thoughts on this report? Let us know!

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/axolotl/utils/data/pretraining.py (1)

27-37: Guard against missing eos/pad tokens and clarify truncation intent.

Relying on tokenizer.eos_token_id/pad_token_id without checks can explode for tokenizers lacking them (common in base LLaMA variants pre-patch). Also, max_length=max_tokens - 2 assumes we’ll always append 2 tokens per sample; please assert this contract explicitly.

Apply:

 res = tokenizer(
     examples[text_column],
     truncation=True,
     max_length=max_tokens - 2,
     add_special_tokens=True,
 )
+# Validate special tokens (fail fast with clear guidance)
+if tokenizer.eos_token_id is None or tokenizer.pad_token_id is None:
+    raise ValueError(
+        "encode_streaming requires eos_token_id and pad_token_id. "
+        "Ensure tokenizer has these set (e.g., tokenizer.add_special_tokens or config)."
+    )
src/axolotl/utils/data/sft.py (1)

181-194: KeyError risk on skip; also propagate text_column if specified.

When cfg.pretraining_dataset is a list of dicts, skip is optional. Accessing config["skip"] will raise if absent. Use .get with default 0. Including text_column improves parity with the simple path and downstream wrappers.

-        config = cfg.pretraining_dataset[0]
-        return DictDefault(
-            {
-                "path": config["path"],
-                "name": config["name"],
-                "skip": config["skip"],
-                "split": config.get("split", "train"),
-                "data_files": config.get("data_files"),
-                "type": config.get("type", "pretrain"),
-            }
-        )
+        config = cfg.pretraining_dataset[0]
+        return DictDefault(
+            {
+                "path": config["path"],
+                "name": config.get("name"),
+                "skip": config.get("skip", 0),
+                "split": config.get("split", "train"),
+                "data_files": config.get("data_files"),
+                "type": config.get("type", "pretrain"),
+                "text_column": config.get("text_column"),
+            }
+        )
🧹 Nitpick comments (48)
src/axolotl/prompt_tokenizers.py (2)

76-81: Preserve output shape invariants on early-return by including labels=[]

When returning early for empty prompts, the BatchEncoding lacks a "labels" key, whereas the non-empty path always sets it. Some downstream code assumes "labels" exists. Add an empty labels array to the early-return object to keep the contract consistent.

-        empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
+        empty = BatchEncoding(data={"input_ids": [], "attention_mask": [], "labels": []})

76-81: Confirm that LOG.warning_once actually de-duplicates emissions; otherwise this is a no-op rename

Switching to warning_once is good only if it really logs once. The current implementation in axolotl.utils.logging.warning_once appears to just proxy to warning with no caching/dedup, so this won’t reduce log spam. Either (a) implement the dedup logic in the logger, or (b) revert to LOG.warning to avoid misleading semantics.

If you want a minimal implementation in src/axolotl/utils/logging.py:

# Example patch: implement once-only by message text
import logging
from functools import lru_cache

class AxoLogger(logging.Logger):
    @lru_cache(maxsize=4096)
    def _warn_once_cache(self, msg: str):
        return True

    def warning_once(self, msg, *args, **kwargs):
        if self._warn_once_cache(str(msg)):
            super().warning(msg, *args, **kwargs)

Then ensure get_logger returns AxoLogger instances. I can provide a full patch if helpful.

src/axolotl/datasets.py (5)

4-7: Nit: fix grammar in module docstring (“Let’s” vs “Lets”)

Tiny polish for user-facing docs.

-We want this to be a wrapper for an existing dataset that we have loaded. Lets use the
+We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
 concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
 datasets.

28-35: Type tightening: keep_in_memory should be bool, not Optional[bool]

datasets.map expects a bool for keep_in_memory. Annotating as Optional increases the chance of passing None through. Make the type bool and default False.

-        process_count: int | None = None,
-        keep_in_memory: bool | None = False,
+        process_count: int | None = None,
+        keep_in_memory: bool = False,

45-51: Normalize remove_columns input to a list

For consistency with the IterableDataset branch—and to avoid any surprises with dict_keys views—cast features to list before passing to remove_columns.

-        features = dataset.features.keys()
+        features = list(dataset.features.keys())

62-69: Optional: make batch size configurable and consistent across streaming/non-streaming paths

Non-iterable path sets batch_size=1_000 when supports_batched is True; the IterableDataset path leaves it unset (uses library default). Consider threading a common batch_size from the tokenizer or kwargs to keep behavior consistent and tuneable.

Example minimal change (mirrors the non-streaming 1_000 default):

         return dataset.map(
             self.prompt_tokenizer.tokenize_prompt,
             num_proc=self.process_count,
             remove_columns=features,
             keep_in_memory=self.keep_in_memory,
             desc="Tokenizing Prompts",
-            **map_kwargs,
+            **({"batch_size": 1_000} | map_kwargs),
         )

And in wrap_dataset_for_tokenized_prompt (IterableDataset):

-        if prompt_tokenizer.supports_batched:
-            map_kwargs["batched"] = True
+        if prompt_tokenizer.supports_batched:
+            map_kwargs["batched"] = True
+            map_kwargs["batch_size"] = 1_000

If you prefer, expose batch_size via kwargs or a field on PromptTokenizingStrategy.


77-86: IterableDataset.features may be absent/partial; guard remove_columns computation

Some streaming sources don’t define features eagerly. If features is None, list(dataset.features.keys()) will fail. Add a fallback to infer columns from a sample row or skip remove_columns when unknown.

-        features = list(dataset.features.keys())
+        features = list(dataset.features.keys()) if dataset.features is not None else None
         return dataset.map(
             prompt_tokenizer.tokenize_prompt,
-            remove_columns=features,
+            **({} if features is None else {"remove_columns": features}),
             **map_kwargs,
         )

If you want me to add a safe “peek 1 row” helper for IterableDataset to infer columns, I can draft it.

tests/e2e/integrations/test_kd.py (1)

28-28: Switching to qwen3 chat template looks consistent; consider aligning test naming.

The change to chat_template: "qwen3" matches the Qwen3 base model and dataset artifacts used in this test. To reduce confusion, consider renaming test_llama_kd (Line 82) to reflect Qwen3, since the model/template stack is no longer LLaMA-based. Non-blocking.

src/axolotl/utils/data/utils.py (1)

193-207: Make streaming detection explicit and reduce false positives in warnings.

The current guard infers streaming via the presence/absence of column_names, and warns about missing input_ids before checking for a streaming dataset. This can (a) warn in legitimate pre-tokenization/streaming paths and (b) miss the simpler/clearer type check via IterableDataset.

Suggest handling streaming first and softening the message for non-input_ids datasets (e.g., reward modeling, pre-tokenization). This preserves behavior while improving clarity.

Apply this focused refactor:

-    if (
-        hasattr(dataset, "column_names")
-        and dataset.column_names
-        and "input_ids" not in dataset.column_names
-    ):
-        LOG.warning(
-            "Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
-            "expected for reward modeling."
-        )
-        return dataset
-    elif not hasattr(dataset, "column_names") or dataset.column_names is None:
-        LOG.info(
-            "Dataset is streaming (IterableDataset), skipping long sequence handling"
-        )
-        return dataset
+    # Prefer explicit streaming detection first
+    if isinstance(dataset, IterableDataset) and not getattr(dataset, "column_names", None):
+        LOG.info(
+            "Dataset is streaming (IterableDataset) without a column schema; skipping long-sequence handling."
+        )
+        return dataset
+    # If we have a schema but no input_ids, don't attempt long-seq handling
+    if getattr(dataset, "column_names", None) and "input_ids" not in dataset.column_names:
+        LOG.warning(
+            "Dataset does not contain 'input_ids'. Skipping long-sequence handling (common for reward modeling or pre-tokenization stages)."
+        )
+        return dataset
src/axolotl/cli/args.py (1)

17-23: Use a non-optional bool for a deprecated flag.

iterable is now deprecated and has a concrete default. Making it bool (not Optional[bool]) avoids a tri-state and simplifies downstream checks.

Apply this small type tweak:

-    iterable: Optional[bool] = field(
+    iterable: bool = field(
         default=False,
         metadata={
             "help": (
                 "[DEPRECATED] No longer supported. For streaming datasets, use "
                 "'axolotl train' and set 'streaming: true' in your YAML config, or "
                 "pass --streaming instead in the CLI."
             )
         },
     )
src/axolotl/cli/preprocess.py (2)

38-46: Return a non-zero exit code when an unsupported flag is used.

Currently we log an error and return, which yields exit code 0 from the CLI. This can mask misuse in scripts/CI. Prefer raising SystemExit(2) after logging so callers can detect the failure.

Apply this minimal change:

     if cli_args.iterable:
         LOG.error(
             "The --iterable CLI argument for 'axolotl preprocess' is no longer "
             "supported. For training, set 'streaming: true' in your YAML config or "
             "pass '--streaming' in your 'axolotl train' command for on-the-fly "
             "preprocessing."
         )
-        return
+        raise SystemExit(2)

50-52: Minor consistency nit: unify CLI name formatting.

Elsewhere we tend to wrap CLI names with backticks in messages. Consider switching "Run the 'axolotl train' CLI..." to backticks for consistency (or standardize repository-wide).

src/axolotl/utils/schemas/validation.py (2)

343-363: Great deprecation shim; use warning_once to avoid log spam.

The migration from pretrain_multipack_buffer_size to streaming_multipack_buffer_size is handled well. To reduce repeated warnings across multiple validations/parses, prefer LOG.warning_once (already used elsewhere in this module).

Apply this small change:

-            LOG.warning(
+            LOG.warning_once(
                 "`pretrain_multipack_buffer_size` is deprecated. Please use `streaming_multipack_buffer_size` instead."
             )

1098-1126: Validation rules are solid; make val_set_size check robust to 0/0.0.

The truthiness check handles most cases, but being explicit about zero avoids edge cases if a config loader ever passes "0" as a string. Optional nit.

Suggested tweak:

-        if data.get("pretraining_dataset") and data.get("val_set_size"):
+        val = data.get("val_set_size")
+        if data.get("pretraining_dataset") and val not in (None, 0, 0.0, "0", "0.0"):
             raise ValueError(
                 "val_set_size is not supported with pretraining_dataset. "
                 "Use test_datasets to specify evaluation datasets for pretraining."
             )
-        if data.get("streaming") and data.get("val_set_size"):
+        val = data.get("val_set_size")
+        if data.get("streaming") and val not in (None, 0, 0.0, "0", "0.0"):
             raise ValueError(
                 "val_set_size is not supported with streaming datasets. "
                 "Use test_datasets to specify evaluation datasets when streaming is enabled."
             )

And keeping the max_steps requirement for streaming is the right call. LGTM there.

src/axolotl/utils/data/pretraining.py (8)

37-42: Non-concatenate path likely needs label masking for PAD (and possibly BOS) consistency.

When concatenate=False, labels mirror input_ids (no -100 masking for pad). This diverges from the concatenated path and can hurt loss. Consider masking padding (and optionally BOS depending on training objective).

Example:

 if not concatenate:
-    return {
-        "input_ids": [seq.tolist() for seq in input_ids],
-        "labels": [seq.tolist() for seq in targets],
-        "attention_mask": [seq.tolist() for seq in attention_mask],
-    }
+    masked_labels = []
+    for ids, mask in zip(input_ids, attention_mask):
+        # ignore PAD tokens in loss
+        lbl = ids.clone()
+        lbl[mask == 0] = -100
+        masked_labels.append(lbl)
+    return {
+        "input_ids": [seq.tolist() for seq in input_ids],
+        "labels": [seq.tolist() for seq in masked_labels],
+        "attention_mask": [seq.tolist() for seq in attention_mask],
+    }

131-165: Pad leftovers without a while loop.

The while loop always executes once (padding to exact max_tokens in a single cat). This can be simplified and made faster.

-    if buffer_input_ids.numel() > 0:  # for any leftover tokens
-        while buffer_input_ids.numel() < max_tokens:  # make all sequences equal in size
-            buffer_input_ids = torch.cat( ... )
-            buffer_labels = torch.cat( ... )
-            buffer_attention_mask = torch.cat( ... )
+    if buffer_input_ids.numel() > 0:  # for any leftover tokens
+        pad_len = max_tokens - buffer_input_ids.numel()
+        if pad_len > 0:
+            buffer_input_ids = torch.cat(
+                (buffer_input_ids, torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)), dim=0
+            )
+            buffer_labels = torch.cat(
+                (buffer_labels, torch.full((pad_len,), -100, dtype=torch.long)), dim=0
+            )
+            buffer_attention_mask = torch.cat(
+                (buffer_attention_mask, torch.zeros((pad_len,), dtype=torch.long)), dim=0
+            )

175-176: Improve debug log context.

LOG.debug(len(...)) without a message makes tracing logs painful.

-LOG.debug(len(ret["input_ids"]))
+LOG.debug("encode_streaming produced %d sequences of length %d", len(ret["input_ids"]), max_tokens)

208-210: Avoid mutating cfg.micro_batch_size as a side effect.

Overwriting cfg at runtime can surprise downstream components, logs, or callbacks that rely on the configured value. Since you already captured the original micro_batch_size in the functools.partial, prefer using a local override for the data loader only.

-# Set this to 1 so downstream data_loader doesn't try to increase the batch size again
-cfg.micro_batch_size = 1
+# Avoid mutating cfg; keep the actual data_loader batch size at 1 via the collator/sampler path
+# (If a consumer depends on cfg.micro_batch_size later, consider introducing cfg._effective_loader_batch_size)

If mutation is required for compatibility, at least stash-and-restore around dataset construction.


220-224: Shuffle buffer sized to streaming_multipack_buffer_size — OK but document memory trade-offs.

Using a large shuffle buffer here can spike RAM. Consider logging the effective buffer size on startup for operator awareness.

-    if cfg.shuffle_merged_datasets:
+    if cfg.shuffle_merged_datasets:
+        LOG.info("Shuffling streaming dataset with buffer_size=%d and seed=%s",
+                 cfg.streaming_multipack_buffer_size, cfg.seed)
         dataset = dataset.shuffle(
             seed=cfg.seed, buffer_size=cfg.streaming_multipack_buffer_size
         )

241-243: Batch size for map equals streaming_multipack_buffer_size — sanity-check defaults.

This couples mapping batch size to the packing buffer. Reasonable, but consider capping to a safe upper bound (e.g., min(buffer_size, 8192)) or exposing a separate encode_batch_size for fine control.

No code diff necessary if you prefer current behavior; a config knob would suffice.


282-285: Duplicate deletion of num_truncated_tokens.

The key removal appears twice.

-            if "num_truncated_tokens" in features:
-                del features["num_truncated_tokens"]
             if "num_truncated_tokens" in features:
                 del features["num_truncated_tokens"]

277-297: Return python lists for map-compatibility and parity with encode_streaming.

encode_streaming returns lists of ints; encode_packed_streaming returns lists of Tensors. HF streaming map often tolerates tensors, but returning native lists avoids surprises and aligns keys.

-                chunked_data[feature].append(collated_features[feature].squeeze(0))
+                chunked_data[feature].append(
+                    collated_features[feature].squeeze(0).tolist()
+                )

Also consider sorting keys to ensure deterministic column order.

tests/test_data.py (3)

14-17: Rename test class/docstring for clarity.

Now that the subject is encode_streaming, the class/docstring name “EncodePretraining” is misleading.

-class TestEncodePretraining(unittest.TestCase):
-    """
-    test class for encode pretraining and md5 helper
-    """
+class TestEncodeStreaming(unittest.TestCase):
+    """Tests for encode_streaming and the md5 helper."""

32-59: Add a test for concatenate=False and enforce no double-EOS.

The current test only covers the concatenated path. Add a case for concatenate=False and one that checks we don’t produce double EOS when the tokenizer already adds EOS.

Happy to draft the additional tests if you share your preferred tokenizer fixture constraints.


21-29: Offline tokenizer setup is good; consider pinning a tiny tokenizer for speed.

huggyllama/llama-7b is large; even offline, tokenizer files are heavier than needed. A tiny tokenizer (e.g., a small BPE from a toy model) would speed up tests.

If swapping is non-trivial, keep as is.

examples/streaming/README.md (4)

3-4: Tighten wording.

Minor grammar/style polish.

-This directory contains example configurations for using Axolotl's streaming dataset
-functionality, which enables memory-efficient training with large datasets.
+This directory contains example configurations for Axolotl’s streaming datasets,
+which enable memory‑efficient training on large datasets.

10-16: Clarify model/dataset phrasing and bullets.

-Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset
-with SmolLM2-135M.
+Demonstrates a streaming configuration for pretraining on FineWeb‑Edu
+using the SmolLM2‑135M base model.
 
-- Uses `pretraining_dataset` configuration for automatic streaming
-- Multipack attention control to prevent cross-attention between packed sequences
-- Buffer size configuration for memory management
+- Automatic streaming via `pretraining_dataset`
+- Multipack attention to prevent cross‑attention between packed sequences
+- Tune buffer size for memory/throughput trade‑offs

21-24: Clarify SFT bullets.

-- Explicit `streaming: true` flag for SFT datasets
-- Evaluation dataset handling with streaming
-- Memory-efficient training on instruction datasets
+- Enable with `streaming: true` for SFT datasets
+- Evaluation datasets are loaded normally (not streamed)
+- Memory‑efficient SFT on instruction datasets

60-62: Add a tip about local HF cache.

Mention HF_DATASETS_CACHE to minimize network bottlenecks.

-- Download small / frequently-used datasets locally for better performance
+- Cache datasets locally (HF_DATASETS_CACHE) for better performance
docs/streaming.qmd (4)

30-43: Pretraining auto-streaming section reads well. Add a one-liner about “no preprocess needed.”

Matches the PR narrative and examples; adding an explicit note will reduce questions.

 For pretraining tasks, streaming is automatically enabled when using
 `pretraining_dataset`:
 
 ```yaml
@@
 sample_packing: true

+You don’t need to run axolotl preprocess when using streaming pretraining.


---

`83-86`: **Clarify multipack attention semantics between SFT and pretraining.**

Make explicit that SFT isolates attention automatically; pretraining uses pretrain_multipack_attn.


```diff
-# For SFT: attention is automatically isolated between packed samples
-# For pretraining: control with pretrain_multipack_attn
+# For SFT: attention is automatically isolated between packed samples (always enabled)
+# For pretraining: enable isolation via pretrain_multipack_attn

106-110: Evaluation not streamed — good. Add a pointer to why (determinism).

-Evaluation datasets are not streamed to ensure consistent evaluation metrics. They're
-loaded normally even when training uses streaming.
+Evaluation datasets are not streamed to ensure deterministic, comparable metrics.
+They’re loaded normally even when training uses streaming.

61-73: Consider warning callouts for memory knobs.

This section is great; a short caution callout would stand out in rendered docs.

Add:

::: {.callout-warning}
Large `streaming_multipack_buffer_size` and dataset shuffling can significantly increase RAM usage.
Start smaller and scale up based on headroom.
:::
tests/test_packed_pretraining.py (1)

79-86: Future-proof the test against cfg.micro_batch_size mutation inside wrap_streaming_dataset.

This test correctly captures original_bsz before calling wrap_streaming_dataset, relying on the fact that wrap_streaming_dataset sets cfg.micro_batch_size = 1 only after capturing the prior value into encode’s batch_size. To make the intent explicit and guard against future refactors that might change that order, add a quick assertion after the call.

Apply this diff:

         original_bsz = cfg.micro_batch_size
         train_dataset = wrap_streaming_dataset(
             dataset,
             tokenizer_huggyllama,
             cfg,
             ds_wrapper_partial,
         )
+        # wrap_streaming_dataset flattens packed samples into a single item and sets
+        # micro_batch_size to 1 for the downstream DataLoader. Keep this invariant visible.
+        assert cfg.micro_batch_size == 1
examples/streaming/sft.yaml (4)

10-13: Avoid underscores in numeric literals in YAML for maximum loader compatibility.

Some YAML parsers (and older loader stacks) don’t accept underscore-separated numerals. Switching 10_000 → 10000 avoids parsing ambiguity.

-streaming_multipack_buffer_size: 10_000
+streaming_multipack_buffer_size: 10000

24-26: Clarify packing expectations with micro_batch_size=1.

With sample_packing: true and micro_batch_size: 1, packing still works (it packs multiple short samples into a single sequence), but users sometimes expect “flattened” mega-batches from micro_batch_size>1. A brief comment in this example could reduce confusion.

 sample_packing: true
 flash_attention: true
+# micro_batch_size controls per-step item count; packing still groups multiple short samples
+# into each sequence even when this is 1.
 micro_batch_size: 1

Also applies to: 19-22


45-50: Explicitly set WandB fields to null or remove them to avoid accidental init via environment.

Empty keys are interpreted as null by most loaders, but being explicit (or omitting them) prevents accidental WandB initialization when environment variables are present.

-# Weights & Biases (optional)
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
+# Weights & Biases (optional) — leave commented out or set explicitly to null when needed
+# wandb_project: null
+# wandb_entity: null
+# wandb_watch: null
+# wandb_name: null
+# wandb_log_model: null

19-22: Hardware-sensitive knobs in an example config.

flash_attention: true and tf32: true are great defaults on recent NVIDIA GPUs, but they can surprise users on unsupported hardware. A one-line comment helps set expectations.

 flash_attention: true
+## Requires compatible GPU/software stack; set to false if you see kernel errors.
 ...
 tf32: true
+## Effective on Ampere+ GPUs; harmless on others but can be disabled if desired.

Also applies to: 35-37

tests/e2e/test_streaming.py (2)

31-33: pretrain_multipack_attn is unused in SFT configs.

In the streaming SFT path (datasets provided; no pretraining_dataset), wrap_streaming_dataset forces multipack_attn=True and ignores pretrain_multipack_attn. Dropping it reduces confusion.

-                "pretrain_multipack_attn": sample_packing,
                 "streaming_multipack_buffer_size": 10000,
                 "dataset_processes": 1,

51-53: Consolidate Optimizer Configuration Across Tests

Given that "adamw_torch_fused" is used in dozens of test files (over 90 occurrences found), updating each instance individually is error-prone and hard to maintain. Instead, centralize the optimizer choice and introduce a fallback for environments where the fused variant isn’t available. For example:

  • Define a helper in your test utilities (e.g. tests/utils/optimizers.py):

    def get_default_optimizer_name():
        try:
            # this will raise or be unavailable on CPU-only builds
            import torch
            torch._C._jit_get_executor_state()  # dummy check for fused support
            return "adamw_torch_fused"
        except Exception:
            return "adamw_torch"
  • In your tests or builders, replace literal strings with:

    from tests.utils.optimizers import get_default_optimizer_name
    ...
        "optimizer": get_default_optimizer_name(),

This approach:

  • Ensures CI portability without touching each test.
  • Keeps performance benefits on CUDA-enabled runners.
  • Simplifies future changes to optimizer defaults.

Locations requiring this refactor (non-exhaustive):

• tests/test_packed_dataset.py:66
• tests/e2e/test_profiler.py:39
• tests/e2e/test_streaming.py:51
• tests/core/test_builders.py:39
• …and all other instances of "optimizer": "adamw_torch_fused" under tests/

src/axolotl/utils/data/__init__.py (2)

4-6: Provide backward-compat aliases for renamed APIs to ease downstream migration.

This rename is reasonable, but external users may still import encode_pretraining/wrap_pretraining_dataset. Keep soft aliases with a deprecation warning for a release or two.

-from axolotl.utils.data.pretraining import (
-    encode_streaming,
-    wrap_streaming_dataset,
-)
+import warnings
+from axolotl.utils.data.pretraining import encode_streaming, wrap_streaming_dataset
+
+# Backward-compat (deprecated): keep old names alive with warnings
+def encode_pretraining(*args, **kwargs):
+    warnings.warn(
+        "encode_pretraining is deprecated; use encode_streaming instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    return encode_streaming(*args, **kwargs)
+
+def wrap_pretraining_dataset(*args, **kwargs):
+    warnings.warn(
+        "wrap_pretraining_dataset is deprecated; use wrap_streaming_dataset instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    return wrap_streaming_dataset(*args, **kwargs)

14-21: Add deprecated aliases to __all__ for backward compatibility

I ran the following search to ensure there are no lingering references to the old function names outside of the shim in __init__.py:

rg -nP '\b(wrap_pretraining_dataset|encode_pretraining)\b' -g '!src/axolotl/utils/data/__init__.py'

No matches were found, confirming it’s safe to introduce these deprecated exports without breaking internal imports.

Apply this optional refactor to include the deprecated names:

 __all__ = [
     "encode_streaming",
     "wrap_streaming_dataset",
+    # Deprecated exports (remove in a future major)
+    "encode_pretraining",
+    "wrap_pretraining_dataset",
     "prepare_preference_datasets",
     "get_dataset_wrapper",
     "prepare_datasets",
     "md5",
 ]
src/axolotl/utils/schemas/config.py (2)

492-499: Mark pretrain_multipack_buffer_size as deprecated in the schema and auto-migrate to streaming_multipack_buffer_size.

The comment notes deprecation, but adding formal deprecation metadata plus an auto-migration guard reduces user friction and improves generated schemas.

-    # Deprecated: Use streaming_multipack_buffer_size instead
-    pretrain_multipack_buffer_size: int | None = None
+    # Deprecated: Use streaming_multipack_buffer_size instead
+    pretrain_multipack_buffer_size: int | None = Field(
+        default=None,
+        deprecated="Use streaming_multipack_buffer_size instead",
+        json_schema_extra={
+            "description": "Deprecated: Use streaming_multipack_buffer_size instead"
+        },
+    )

Additionally, add a before-validator to auto-migrate the value when provided:

@@
 class AxolotlInputConfig(
@@
 ):
@@
     model_config = {"populate_by_name": True}
@@
+    @model_validator(mode="before")
+    @classmethod
+    def migrate_pretrain_buffer_size(cls, data):
+        # Migrate deprecated field when the new field is unset
+        if (
+            data.get("pretrain_multipack_buffer_size") is not None
+            and data.get("streaming_multipack_buffer_size") is None
+        ):
+            data["streaming_multipack_buffer_size"] = data[
+                "pretrain_multipack_buffer_size"
+            ]
+            LOG.warning(
+                "pretrain_multipack_buffer_size is deprecated; "
+                "using its value for streaming_multipack_buffer_size."
+            )
+        return data

507-509: Guardrails for streaming_multipack_buffer_size.

Consider validating that the buffer is positive and not absurdly large relative to available RAM to prevent OOMs during dataset.map on big corpora.

Option A (schema-level constraint): add ge=1 and a brief description.

-    streaming_multipack_buffer_size: int | None = 10_000
+    streaming_multipack_buffer_size: int | None = Field(
+        default=10_000,
+        ge=1,
+        json_schema_extra={"description": "Batch size for streaming map/packing buffer"}
+    )

Option B (validator): cap or warn when > 1e6.

examples/streaming/pretrain.yaml (2)

15-18: Consider setting a deterministic seed.

Adding seed makes shuffling and sampling reproducible across runs.

 # Training configuration
+seed: 42
 max_steps: 1000
 output_dir: ./outputs/smollm2-135m-pretrain-streaming

40-45: Optional: add eval to monitor training.

If you want on-the-fly validation in this example, add a small test_datasets block and set eval_sample_packing accordingly.

tests/test_streaming.py (1)

34-37: Make log assertion resilient to message ordering.

Indexing cm.output[0] is brittle. Assert against any captured warning.

-        with self.assertLogs("axolotl.utils.schemas.validation", level="WARNING") as cm:
-            validated_cfg = validate_config(cfg_old)
-            self.assertIn("pretrain_multipack_buffer_size` is deprecated", cm.output[0])
+        with self.assertLogs("axolotl.utils.schemas.validation", level="WARNING") as cm:
+            validated_cfg = validate_config(cfg_old)
+        self.assertTrue(
+            any("pretrain_multipack_buffer_size` is deprecated" in msg for msg in cm.output),
+            "Expected deprecation warning for pretrain_multipack_buffer_size"
+        )
src/axolotl/utils/data/sft.py (1)

208-211: Nit: variable naming.

pretraining_config now covers SFT-with-streaming too. Consider renaming to dataset_config for clarity (no functional change).

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between e1131e9 and 6de16b6.

📒 Files selected for processing (23)
  • _quarto.yml (1 hunks)
  • docs/streaming.qmd (1 hunks)
  • examples/streaming/README.md (1 hunks)
  • examples/streaming/pretrain.yaml (1 hunks)
  • examples/streaming/sft.yaml (1 hunks)
  • src/axolotl/cli/args.py (1 hunks)
  • src/axolotl/cli/preprocess.py (1 hunks)
  • src/axolotl/common/datasets.py (0 hunks)
  • src/axolotl/datasets.py (1 hunks)
  • src/axolotl/prompt_tokenizers.py (1 hunks)
  • src/axolotl/utils/data/__init__.py (2 hunks)
  • src/axolotl/utils/data/pretraining.py (3 hunks)
  • src/axolotl/utils/data/sft.py (8 hunks)
  • src/axolotl/utils/data/shared.py (0 hunks)
  • src/axolotl/utils/data/utils.py (1 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (2 hunks)
  • tests/e2e/integrations/test_kd.py (1 hunks)
  • tests/e2e/test_streaming.py (1 hunks)
  • tests/test_data.py (2 hunks)
  • tests/test_packed_dataset.py (0 hunks)
  • tests/test_packed_pretraining.py (2 hunks)
  • tests/test_streaming.py (1 hunks)
💤 Files with no reviewable changes (3)
  • src/axolotl/common/datasets.py
  • src/axolotl/utils/data/shared.py
  • tests/test_packed_dataset.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-08T07:22:40.131Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3038
File: examples/slurm/axolotl.slurm:16-16
Timestamp: 2025-08-08T07:22:40.131Z
Learning: In Axolotl (PR #3038), the preprocess codepath sets AXOLOTL_IS_PREPROCESS internally, so external scripts (e.g., examples/slurm/axolotl.slurm) need not export it for the early-return in src/axolotl/utils/data/sft.py to trigger.

Applied to files:

  • src/axolotl/cli/preprocess.py
🧬 Code graph analysis (8)
tests/test_streaming.py (4)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/utils/data/sft.py (2)
  • _prepare_streaming_dataset (125-176)
  • prepare_datasets (48-65)
src/axolotl/utils/config/__init__.py (1)
  • validate_config (259-303)
src/axolotl/utils/data/pretraining.py (1)
  • wrap_streaming_dataset (179-244)
src/axolotl/prompt_tokenizers.py (1)
src/axolotl/utils/logging.py (1)
  • warning_once (31-39)
tests/test_data.py (3)
src/axolotl/utils/data/pretraining.py (1)
  • encode_streaming (20-176)
src/axolotl/utils/data/utils.py (1)
  • md5 (73-78)
tests/test_datasets.py (1)
  • tokenizer (31-35)
tests/e2e/test_streaming.py (3)
src/axolotl/utils/config/__init__.py (2)
  • normalize_config (97-237)
  • validate_config (259-303)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
tests/e2e/utils.py (2)
  • check_model_output_exists (162-183)
  • check_tensorboard (140-159)
tests/test_packed_pretraining.py (2)
src/axolotl/utils/data/wrappers.py (1)
  • get_dataset_wrapper (57-131)
src/axolotl/utils/data/pretraining.py (1)
  • wrap_streaming_dataset (179-244)
src/axolotl/utils/data/__init__.py (1)
src/axolotl/utils/data/pretraining.py (2)
  • encode_streaming (20-176)
  • wrap_streaming_dataset (179-244)
src/axolotl/utils/data/pretraining.py (1)
src/axolotl/utils/collators/batching.py (1)
  • PretrainingBatchSamplerDataCollatorForSeq2Seq (200-227)
src/axolotl/utils/data/sft.py (3)
src/axolotl/utils/data/lock.py (1)
  • FileLockLoader (17-68)
src/axolotl/utils/data/pretraining.py (1)
  • wrap_streaming_dataset (179-244)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
🪛 LanguageTool
examples/streaming/README.md

[grammar] ~3-~3: There might be a mistake here.
Context: ...ns for using Axolotl's streaming dataset functionality, which enables memory-effi...

(QB_NEW_EN)


[grammar] ~10-~10: There might be a mistake here.
Context: ...ning tasks using the fineweb-edu dataset with SmolLM2-135M. - Uses `pretraining_...

(QB_NEW_EN)


[grammar] ~21-~21: There might be a mistake here.
Context: ... streaming: true flag for SFT datasets - Evaluation dataset handling with streami...

(QB_NEW_EN)


[grammar] ~22-~22: There might be a mistake here.
Context: ...aluation dataset handling with streaming - Memory-efficient training on instruction...

(QB_NEW_EN)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: preview
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
🔇 Additional comments (24)
src/axolotl/utils/data/pretraining.py (6)

49-63: Potential double-EOS depending on tokenizer behavior.

Some tokenizers add EOS when add_special_tokens=True; you then append another EOS here. That can insert double EOS boundaries. If intentional, fine; otherwise gate on whether the last token is already EOS.

Suggested defensive check:

-        input_ids[i] = torch.cat(
+        # Avoid double-EOS if tokenizer already appended it
+        eos_pad = [tokenizer.eos_token_id, tokenizer.pad_token_id]
+        if len(input_ids[i]) > 0 and input_ids[i][-1].item() == tokenizer.eos_token_id:
+            eos_pad = [tokenizer.pad_token_id]  # only pad if EOS present
+        input_ids[i] = torch.cat(
             (
                 input_ids[i],
-                torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
+                torch.tensor(eos_pad),
             ),
             dim=0,
         )
-        targets[i] = torch.cat(
+        tgt_tail = [tokenizer.eos_token_id, -100] if len(eos_pad) == 2 else [-100]
+        targets[i] = torch.cat(
             (
                 targets[i],
-                torch.tensor([tokenizer.eos_token_id, -100]),
+                torch.tensor(tgt_tail),
             ),
             dim=0,
         )

186-198: Multipack attention policy: SFT always True — good; pretraining uses cfg.pretrain_multipack_attn.

This is consistent with isolating attention for SFT. No action needed.


231-235: Peeking a streaming dataset to infer remove_columns can consume one sample.

Iterating a streaming dataset here risks dropping the first example for subsequent consumers in some dataset implementations.

Prefer dataset.features if available, or use dataset.take(1) (supported by HF streaming) to avoid side effects:

-    if dataset.features is None:
-        for first_row in dataset:
-            remove_columns = list(first_row.keys())
-            break
+    if dataset.features is None:
+        try:
+            first = next(iter(dataset.take(1)))
+        except AttributeError:
+            first = next(iter(dataset))
+        remove_columns = list(first.keys())

Please test with an actual IterableDataset to confirm no sample loss.


259-266: Attention-mask handling during packing: confirm collator/trainer compatibility.

With drop_attention_mask=multipack_attn, features may lack "attention_mask", so the custom collator’s per-pack attention segmentation won’t apply. That matches the FIXME comment (rely on position_ids workaround), but it means multipack_attn in collator is effectively ignored when True.

  • Confirm the trainer path correctly enables the position-ids isolation when attention masks are dropped.
  • Consider consolidating the logic to avoid conflicting signals (e.g., pass multipack_attn only to the trainer path when dropping masks).

268-275: Sampler params look sensible for streaming-packed flow.

batch_max_len=batch_size * max_seq_length with batch_size=1 leverages the multipack sampler as intended.


70-70: Confirmed Python Floor ≥3.10 — no changes needed

The project’s pyproject.toml declares requires-python = ">=3.10", so using zip(..., strict=False) (which requires Python 3.10+) is safe.

_quarto.yml (1)

275-276: Good addition to Core Concepts navigation.

docs/streaming.qmd placed right after dataset_preprocessing makes sense.

tests/test_data.py (1)

9-10: API import changes look correct.

Switch to encode_streaming and keeping md5 import aligned with init exports is fine.

examples/streaming/README.md (1)

45-49: Confirm YAML numeric underscores are accepted by your loader.

Some YAML parsers don’t accept 10_000 as a numeric literal. If your config loader handles it, great; otherwise quote it or use 10000.

If needed:

-streaming_multipack_buffer_size: 10_000
+streaming_multipack_buffer_size: 10000
docs/streaming.qmd (1)

88-89: Link target looks correct for same-directory doc; consider adding sidebar link text for clarity.

No code change needed; just confirming relative link works given both are under docs/.

tests/test_packed_pretraining.py (1)

12-12: Import rename to streaming API looks correct.

The switch to import wrap_streaming_dataset aligns with the public API change and keeps this test on the streaming path. No issues.

tests/e2e/test_streaming.py (2)

18-21: Good coverage: parameterizing packing on/off.

Exercising both sample_packing=True/False in the same test keeps the streaming path honest across packer modes.


60-66: End-to-end flow and postconditions look solid.

validate_config + normalize_config before load_datasets/train, followed by artifact checks and a basic loss sanity check via TensorBoard is a pragmatic E2E signal.

src/axolotl/utils/schemas/config.py (2)

500-506: Confirmed: pretraining_sample_concatenation is still referenced correctly, no legacy aliases detected

All occurrences of the relocated field use the exact same name, so the move won’t affect runtime behavior:

  • Declaration remains at src/axolotl/utils/schemas/config.py:500
  • Used in data loader: src/axolotl/utils/data/pretraining.py:217
  • Used in model builder: src/axolotl/core/builders/causal.py:428

No references to any outdated or aliased field name were found.


507-507: No changes needed: streaming defaults to False end-to-end

I’ve traced all the key code paths:

  • In the schema, streaming: bool | None = None allows users to omit the flag.
  • In data loading (utils/data/shared.py), the function signature is
    def load_dataset(..., streaming: bool = False, ...):
        ...
        return {"streaming": streaming, …}
    so if the config value is None, the default False is used.
  • Callers either pass an explicit boolean or rely on that default:
    • RL pipeline in utils/data/rl.py always calls with streaming=False, regardless of config.
    • SFT pipeline in utils/data/sft.py gates on cfg.streaming (where None is falsy) and only requests streaming when explicitly set.
    • Validation in utils/schemas/validation.py checks data.get("streaming") (where NoneFalse) before applying streaming‐only rules.

All entry points thus fall back deterministically to non‐streaming when streaming is unset. No further changes are required here.

examples/streaming/pretrain.yaml (1)

36-38: Double-check that bf16: auto is accepted by your config/Trainer path.

Some stacks expect bf16 to be a boolean and infer automatically based on hardware without the "auto" sentinel. If your schema explicitly supports "auto", ignore this.

tests/test_streaming.py (3)

143-165: Potential fragility: test depends on dict-like cfg.datasets[0] being attribute-accessible.

Current implementation of _prepare_streaming_dataset (see sft.py Lines 141–147) reads cfg.datasets[0].type. If cfg.datasets[0] is a plain dict (as in this test), that attribute access would raise. I’m proposing a fix in sft.py to robustly handle both dicts and model instances. If you don’t adopt that fix, consider wrapping the config via validate_config here before calling _prepare_streaming_dataset.


166-200: Good check: SFT forces multipack_attn=True.

This verifies our intended invariant for SFT packing. Looks solid.


201-235: Good check: pretraining respects pretrain_multipack_attn.

This keeps flexibility for pretraining while enforcing attention isolation for SFT.

src/axolotl/utils/data/sft.py (5)

63-65: Routing looks correct.

Delegating to streaming path when either cfg.streaming or cfg.pretraining_dataset is set is consistent with the new model.


171-176: Nice: clear messaging and step computation for streaming.

Explicitly logging unavailability of dedup for streaming and returning max_steps or -1 is reasonable.


243-249: Wrapper switch to wrap_streaming_dataset looks good.

This aligns with the streaming-first approach and matches the collator/encoding flow used elsewhere.


385-394: Good: IterableDatasetDict support on split selection.

This unifies handling for both in-memory and streaming dict datasets.


349-356: Correct: skip long-seq processing for iterable preprocessing.

Avoiding handle_long_seq_in_dataset when preprocess_iterable=True prevents loading the entire stream into memory.

Comment thread examples/streaming/pretrain.yaml
Comment thread src/axolotl/utils/data/pretraining.py
Comment thread src/axolotl/utils/data/sft.py Outdated

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (6)
examples/streaming/README.md (6)

8-9: Clarify run instructions, add both example commands, and note deprecation.

Improves readability, avoids the semicolon splice, and surfaces the --iterable deprecation where users will see it.

-Run the following examples with e.g. `axolotl train examples/streaming/sft.yaml`; no
-`axolotl preprocess` required!
+Run either of the following (no axolotl preprocess step required):
+
+```bash
+axolotl train examples/streaming/pretrain.yaml
+axolotl train examples/streaming/sft.yaml
+```
+
+Note: The `--iterable` flag in `axolotl preprocess` is deprecated; use `streaming: true` (for SFT) or `pretraining_dataset` (for pretraining) instead.

13-15: Be explicit about dataset slugs and config knobs (pretrain_multipack_attn, buffer).

Names the knobs users will look for and aligns phrasing with the examples/config.

-Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset
-with SmolLM2-135M.
+Demonstrates streaming configuration for pretraining using HuggingFaceFW/fineweb-edu
+with SmolLM2-135M.
-- Uses `pretraining_dataset` configuration for automatic streaming
-- Multipack attention control to prevent cross-attention between packed sequences
-- Buffer size configuration for memory management
+- Uses `pretraining_dataset` for automatic streaming
+- Enable `pretrain_multipack_attn` to prevent cross-sample attention when packing
+- Tune buffer size (`streaming_multipack_buffer_size`) for memory/packing trade-offs

Also applies to: 16-18


35-37: Avoid hard-coding the default buffer size unless guaranteed.

If the code’s default changes, this doc goes stale. Either remove the default or verify it matches the implementation.

-- Controls buffer size for sample packing (default: 10,000)
+- Controls buffer size for sample packing

43-46: Tighten grammar for sample_packing bullets.

-- Packs multiple samples into single sequences
-- Minimize per-step padding tokens
+- Packs multiple samples into a single sequence
+- Minimizes per-step padding tokens

47-50: Add actionable performance tips for low-memory environments.

Surface concrete mitigations users will reach for when hitting OOM or slowness.

 - Download small / frequently-used datasets locally for better performance
 - Larger buffer sizes improve packing efficiency
+- If memory is constrained or you hit OOM, reduce `streaming_multipack_buffer_size` or disable `shuffle_merged_datasets`
+- Remote datasets may benefit from local caching to avoid network bottlenecks

6-7: Cross-reference the detailed docs page.

Helps users discover the full guidance added in this PR.

 ## Examples
+
+For a deeper dive, see `docs/streaming.qmd` in the repository.
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between d413b04 and 8492bd1.

📒 Files selected for processing (1)
  • examples/streaming/README.md (1 hunks)
🧰 Additional context used
🪛 LanguageTool
examples/streaming/README.md

[grammar] ~3-~3: There might be a mistake here.
Context: ...ns for using Axolotl's streaming dataset functionality, which enables memory-effi...

(QB_NEW_EN)


[grammar] ~8-~8: There might be a mistake here.
Context: ...l train examples/streaming/sft.yaml; no axolotl preprocess` required! ### Pret...

(QB_NEW_EN)


[grammar] ~13-~13: There might be a mistake here.
Context: ...ning tasks using the fineweb-edu dataset with SmolLM2-135M. - Uses `pretraining_...

(QB_NEW_EN)


[grammar] ~24-~24: There might be a mistake here.
Context: ... streaming: true flag for SFT datasets - Evaluation dataset handling with streami...

(QB_NEW_EN)


[grammar] ~25-~25: There might be a mistake here.
Context: ...aluation dataset handling with streaming - Memory-efficient training on instruction...

(QB_NEW_EN)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: preview

Comment thread examples/streaming/README.md

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
examples/streaming/README.md (1)

51-51: Fix truncated line at end of file

There’s a stray line “51” that will render incorrectly. Remove it to avoid a formatting artifact.

-51
+
src/axolotl/utils/data/sft.py (2)

212-219: Preserve prompt style by parsing dataset type before partializing get_dataset_wrapper.

get_dataset_wrapper in the non-streaming path receives both base_type and prompt_style (parsed via _parse_dataset_type). Here only dataset_base_type is provided; prompt style will be lost for types like "sft:chatml".

Apply:

-    dataset_wrapper_partial = functools.partial(
-        get_dataset_wrapper,
-        dataset_config=pretraining_config,
-        tokenizer=tokenizer,
-        cfg=cfg,
-        dataset_base_type=pretraining_config["type"],
-    )
+    d_base_type, d_prompt_style = _parse_dataset_type(pretraining_config.get("type"))
+    dataset_wrapper_partial = functools.partial(
+        get_dataset_wrapper,
+        dataset_config=pretraining_config,
+        tokenizer=tokenizer,
+        cfg=cfg,
+        dataset_base_type=d_base_type,
+        dataset_prompt_style=d_prompt_style,
+    )

228-235: Bug: streaming loader bypasses auth and shared loader logic.

Directly calling datasets.load_dataset omits cfg.hf_use_auth_token (or token) and any central config logic in load_dataset_with_config. Private datasets will fail.

Apply:

-        iter_dataset = load_dataset(
-            pretraining_config["path"],
-            streaming=True,
-            split=pretraining_config["split"],
-            name=pretraining_config["name"],
-            data_files=pretraining_config["data_files"],
-        )
+        ds = load_dataset_with_config(
+            pretraining_config,
+            cfg.hf_use_auth_token,
+            streaming=True,
+        )
+        if isinstance(ds, (DatasetDict, IterableDatasetDict)):
+            split = pretraining_config.get("split") or "train"
+            iter_dataset = ds[split]
+        else:
+            iter_dataset = ds

I can add a regression test that uses a dummy private dataset gate to ensure token propagation on streaming paths.

♻️ Duplicate comments (2)
examples/streaming/README.md (1)

20-27: SFT docs: mention pretrain_multipack_attn and clarify eval is not streamed (with local-caching tip)

Align with project docs/tests: eval is not streamed; recommend local caching. Also call out the attention mask flag for SFT with packing.

 - Explicit `streaming: true` flag for SFT datasets
-- Memory-efficient training on instruction datasets
-- Evaluation datasets are currently not streamed
+- Enable `pretrain_multipack_attn` when using `sample_packing` to avoid cross-sample attention
+- Memory-efficient training on instruction datasets
+- Evaluation datasets are not streamed; prefer smaller eval sets or local caching
src/axolotl/utils/data/sft.py (1)

141-146: Fix: Normalize cfg.datasets[0] to a DictDefault and ensure type/split are present.

As previously flagged, directly doing DictDefault(cfg.datasets[0]) breaks when the first entry is not a plain dict (e.g., Pydantic model). Also, missing type leads to None flowing into get_dataset_wrapper later. Normalize the shape and guard attributes.

Apply:

-            dataset_config = DictDefault(cfg.datasets[0])
-
-            # Ensure we have a split set - default to 'train' if not specified
-            if not hasattr(dataset_config, "split") or not dataset_config.split:
-                dataset_config.split = "train"
+            first = cfg.datasets[0]
+            # Normalize to a DictDefault mapping regardless of source type
+            if hasattr(first, "model_dump"):
+                dataset_config = DictDefault(first.model_dump())
+            elif isinstance(first, dict):
+                dataset_config = DictDefault(first)
+            else:
+                dataset_config = DictDefault(dict(first))
+
+            # Ensure `type` is present
+            if not getattr(dataset_config, "type", None):
+                dataset_config.type = getattr(first, "type", dataset_config.get("type"))
+
+            # Ensure split default
+            if not getattr(dataset_config, "split", None):
+                dataset_config.split = "train"

Happy to wire a small unit test covering dict, pydantic, and namedtuple inputs.

🧹 Nitpick comments (13)
examples/streaming/README.md (7)

8-10: Clarify preprocess deprecation and fix punctuation around “e.g.”

Make it explicit that preprocess is no longer needed due to deprecation of the iterable mode, and fix the “e.g.” punctuation.

-Run the following examples with e.g. `axolotl train examples/streaming/sft.yaml`; no
-`axolotl preprocess` required!
+Run the following examples with, e.g., `axolotl train examples/streaming/sft.yaml`. No
+`axolotl preprocess` step is required (the `preprocess --iterable` mode is deprecated).

13-15: Capitalize dataset name for consistency

“FineWeb-Edu” is typically capitalized; minor readability polish.

-Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset
-with SmolLM2-135M.
+Demonstrates streaming configuration for pretraining tasks using the FineWeb‑Edu dataset
+with SmolLM2‑135M.

16-19: Name the exact flag: use “pretrain_multipack_attn”

Call out the specific config knob to avoid ambiguity, matching the PR’s behavior change.

-- Multipack attention control to prevent cross-attention between packed sequences
+- Enable `pretrain_multipack_attn` to prevent cross-sample attention when packing

34-38: Optional: note scope of the buffer (CPU RAM) and tuning guidance

A brief qualifier helps users plan memory; safe to keep concise.

 - Controls buffer size for sample packing (default: 10,000)
-- Larger values improve packing efficiency but use more memory
+- Larger values improve packing efficiency but use more CPU RAM
 - Adjust based on available memory

39-42: Add reproducibility note for shuffling

Suggest mentioning shuffle_seed to aid deterministic experiments.

 - Enables shuffling of streaming datasets
-- Requires additional memory for shuffle buffer
+- Requires additional memory for shuffle buffer; set `shuffle_seed` for reproducibility

43-46: Reinforce attention masking requirement when packing

Place the guidance near the option users will toggle.

 ### `sample_packing`
 - Packs multiple samples into single sequences
 - Minimize per-step padding tokens
+- When packing, set `pretrain_multipack_attn: true` to block attention across sample boundaries

49-50: Broaden performance tips with eval caching and buffer tuning advice

These two bullets capture common pitfalls in streaming runs.

 - Download small / frequently-used datasets locally for better performance
-- Larger buffer sizes improve packing efficiency
+- Larger buffer sizes improve packing efficiency
+- Cache evaluation datasets locally to eliminate network variability
+- Sweep `streaming_multipack_buffer_size` (e.g., 2k/5k/10k) to find the best tradeoff for your RAM
src/axolotl/utils/data/sft.py (6)

63-66: Confirm routing: pretraining_dataset forces the streaming path, even if cfg.streaming is False.

Is this intentional? Older configs may set pretraining_dataset for other flows. If the intention is “streaming whenever pretraining_dataset is configured,” keep as-is; otherwise, guard with cfg.streaming to avoid surprising behavior.

If you want to require explicit streaming, consider:

-    if cfg.streaming or cfg.pretraining_dataset:
+    if cfg.streaming and cfg.pretraining_dataset:
         return _prepare_streaming_dataset(cfg, tokenizer, processor)

162-168: Evaluation in streaming mode currently disables iterable preprocessing; consider making it opt-in.

Hard-coding preprocess_iterable=False means test_datasets won’t stream, which can be surprising and memory-heavy for large eval sets. Propagate a config flag (e.g., cfg.eval_streaming) or default to cfg.streaming.

Suggested tweak:

-        _, eval_dataset, _ = _load_and_prepare_datasets(
+        _, eval_dataset, _ = _load_and_prepare_datasets(
             tokenizer,
             cfg,
             split="test",
             processor=processor,
-            preprocess_iterable=False,
+            preprocess_iterable=getattr(cfg, "eval_streaming", cfg.streaming),
         )

171-172: Deduplication log clarity.

Message is accurate: exact dedup isn’t available with streaming. Consider hinting at the non-streaming alternative (preprocess offline) for users encountering this log.


173-175: Prompters list is empty in some streaming paths—verify downstream expectations.

Standard paths return a non-empty prompters list; here you return [] (and earlier return a real list only in the non-packed streaming branch). If any trainer code assumes length to match datasets, this could break formatting/metrics hooks.

Option: thread the dataset_prompter from get_dataset_wrapper through wrap_streaming_dataset (or return a [None] sentinel instead of []) for API consistency.


210-210: Docstring mismatch: function now serves SFT as well.

The docstring says “for pretraining” but this function is used by SFT when sample_packing is enabled.

Apply:

-    """Load and prepare a streaming dataset for pretraining."""
+    """Load and prepare a streaming dataset (SFT or pretraining)."""

253-260: Optional: avoid temp file leakage for placeholder dataset.

NamedTemporaryFile(delete=False) never unlinks, leaving artifacts. If dispatch_batches ensures non-main ranks don’t iterate, a generator-based IterableDataset avoids touching disk.

Example (no diff since it changes implementation detail):

from datasets import IterableDataset

def _create_placeholder_dataset() -> IterableDataset:
    def gen():
        yield {"text": "lorem ipsum dolor sit amet"}
    return IterableDataset.from_generator(gen)

If you keep the CSV, consider os.unlink(f.name) after constructing the dataset (but ensure the dataset doesn’t need the file thereafter).

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 8492bd1 and 17013be.

📒 Files selected for processing (5)
  • docs/streaming.qmd (1 hunks)
  • examples/streaming/README.md (1 hunks)
  • examples/streaming/pretrain.yaml (1 hunks)
  • examples/streaming/sft.yaml (1 hunks)
  • src/axolotl/utils/data/sft.py (8 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • examples/streaming/pretrain.yaml
  • examples/streaming/sft.yaml
  • docs/streaming.qmd
🧰 Additional context used
🧬 Code graph analysis (1)
src/axolotl/utils/data/sft.py (2)
src/axolotl/utils/data/pretraining.py (1)
  • wrap_streaming_dataset (179-244)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
🪛 LanguageTool
examples/streaming/README.md

[grammar] ~3-~3: There might be a mistake here.
Context: ...ns for using Axolotl's streaming dataset functionality, which enables memory-effi...

(QB_NEW_EN)


[grammar] ~8-~8: There might be a mistake here.
Context: ...l train examples/streaming/sft.yaml; no axolotl preprocess` required! ### Pret...

(QB_NEW_EN)


[grammar] ~13-~13: There might be a mistake here.
Context: ...ning tasks using the fineweb-edu dataset with SmolLM2-135M. - Uses `pretraining_...

(QB_NEW_EN)


[grammar] ~24-~24: There might be a mistake here.
Context: ... streaming: true flag for SFT datasets - Memory-efficient training on instruction...

(QB_NEW_EN)


[grammar] ~25-~25: There might be a mistake here.
Context: ...ficient training on instruction datasets - Evaluation datasets are currently not st...

(QB_NEW_EN)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: preview
🔇 Additional comments (6)
examples/streaming/README.md (1)

30-33: LGTM: streaming option description is clear and accurate

Accurately reflects implicit enablement via pretraining_dataset.

src/axolotl/utils/data/sft.py (5)

12-13: Good addition: IterableDatasetDict support for split selection.

Importing IterableDatasetDict enables correct split handling for streaming dataset dicts below. This aligns with your later isinstance checks.


19-19: Importing wrap_streaming_dataset is appropriate for SFT streaming.

This consolidates the streaming encode/wrap logic and keeps SFT aligned with the pretraining path.


148-157: Non-packed streaming path via preprocess_iterable=True looks correct.

Returning early with total_num_steps=-1 (or cfg.max_steps) matches streaming semantics and avoids eager preprocessing.


348-355: Skipping long-seq handling under preprocess_iterable is correct.

This avoids materializing large transformed datasets when streaming. Packing remains gated as expected.


384-389: Split selection now supports IterableDatasetDict—good catch.

This aligns the streaming and non-streaming code paths and prevents KeyErrors on split resolution.

@djsaunde djsaunde mentioned this pull request Aug 26, 2025
5 tasks

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (1)
src/axolotl/utils/data/sft.py (1)

139-151: Potential issue with unsafe attribute access.

The code assumes cfg.datasets[0] has a type attribute on line 146, which will fail if cfg.datasets[0] is a plain dict. This needs defensive handling.

Apply this fix to safely handle both dict and object types:

-            dataset_config = DictDefault(cfg.datasets[0])
+            # Safely extract and normalize the first dataset config
+            first_dataset = cfg.datasets[0]
+            if isinstance(first_dataset, dict):
+                dataset_config = DictDefault(first_dataset)
+            else:
+                dataset_config = DictDefault(first_dataset.model_dump() if hasattr(first_dataset, 'model_dump') else dict(first_dataset))

             # Ensure we have a split set - default to 'train' if not specified
             if not hasattr(dataset_config, "split") or not dataset_config.split:
                 dataset_config.split = "train"
🧹 Nitpick comments (2)
src/axolotl/utils/data/streaming.py (1)

1-1: Documentation clarification needed.

The module docstring should be updated to reflect the expanded scope beyond just streaming, as these functions now handle both streaming and pretraining datasets.

-"""Data handling specific to streaming datasets."""
+"""Data handling for streaming and pretraining datasets with multipack support."""
src/axolotl/utils/data/sft.py (1)

178-178: Consider using logging.warning for important user-facing messages.

Since deduplication is silently skipped for streaming datasets, using LOG.warning instead of LOG.info would better alert users to this limitation.

-        LOG.info("Deduplication not available for streaming datasets")
+        LOG.warning("Deduplication not available for streaming datasets")
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 17013be and aea97a3.

📒 Files selected for processing (4)
  • src/axolotl/utils/data/__init__.py (2 hunks)
  • src/axolotl/utils/data/sft.py (15 hunks)
  • src/axolotl/utils/data/streaming.py (4 hunks)
  • tests/test_streaming.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/test_streaming.py
  • src/axolotl/utils/data/init.py
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/utils/data/streaming.py (1)
src/axolotl/utils/collators/batching.py (1)
  • PretrainingBatchSamplerDataCollatorForSeq2Seq (200-227)
src/axolotl/utils/data/sft.py (2)
src/axolotl/utils/data/streaming.py (1)
  • wrap_streaming_dataset (179-247)
src/axolotl/utils/data/shared.py (2)
  • merge_datasets (525-565)
  • load_dataset_with_config (93-148)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: preview
🔇 Additional comments (15)
src/axolotl/utils/data/streaming.py (6)

20-26: Function name accurately reflects the expanded functionality.

Good job on renaming from encode_pretraining to encode_streaming to better represent the function's expanded use case beyond just pretraining.


179-210: Great implementation of multipack attention logic for SFT datasets.

The conditional logic correctly ensures that SFT datasets use multipack_attn=True for proper attention isolation between packed sequences. The comment clearly explains the reasoning, and setting cfg.micro_batch_size = 1 prevents double batching.


212-221: Clear documentation of unreachable code path.

The comment effectively explains why this code path is not reachable for SFT datasets and points to the correct location in sft.py.


224-226: Proper parameter naming for clarity.

The change from generic buffer_size to streaming_multipack_buffer_size improves code clarity by making the parameter's purpose explicit.


250-257: Function renamed appropriately to match module pattern.

The rename from encode_packed_pretraining to encode_packed_streaming maintains consistency with the other renamed functions.


285-286: Good cleanup - removed redundant deletion.

The removal of the duplicate del features["num_truncated_tokens"] line eliminates unnecessary redundancy.

src/axolotl/utils/data/sft.py (9)

12-12: Necessary import for streaming dataset support.

The addition of IterableDatasetDict import properly supports the new streaming functionality.


19-19: Import updated to match renamed module functions.

The import correctly reflects the renamed wrap_streaming_dataset function.


63-65: Clean routing logic for streaming vs standard datasets.

The conditional routing based on cfg.streaming or cfg.pretraining_dataset is clear and straightforward.


125-183: Well-structured streaming dataset preparation function.

The function properly handles both pretraining and SFT streaming cases with appropriate error handling for unsupported configurations. The early return for non-packed streaming datasets is efficient.


214-257: Clean implementation of streaming dataset loader.

The function properly handles dataset loading with appropriate placeholder creation for non-main processes and correct application of skip logic.


274-274: Parameter rename improves clarity.

The rename from preprocess_iterable to streaming better reflects the parameter's purpose in the context of streaming dataset support.

Also applies to: 283-283, 314-314, 326-326, 347-347, 379-379


355-355: Proper conditional logic for non-streaming operations.

The condition correctly prevents long sequence handling and dataset saving for streaming datasets, which is appropriate since streaming datasets are processed on-the-fly.


391-391: Type checking expanded for streaming support.

The addition of IterableDatasetDict to the type check properly handles streaming dataset dictionaries.


384-384: Streaming parameter properly propagated.

The addition of the streaming parameter to load_dataset_with_config ensures consistent behavior throughout the loading pipeline.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (4)
docs/streaming.qmd (4)

67-71: Spell out behavior for multiple datasets and determinism

Clarify that shuffling is over the merged stream, and call out seed requirements for reproducibility (important for streaming).

Apply this diff:

-When enabled, shuffles the streaming dataset using the buffer. This requires additional
-memory for the shuffle buffer.
+When enabled, shuffles the merged streaming dataset using the buffer (not per-dataset).
+This requires additional memory for the shuffle buffer. For reproducibility across runs
+and distributed workers, set a training seed.

47-57: Optional: surface the SFT attention isolation in the SFT example

Even if force-enabled, a short comment in the SFT snippet helps users discover the behavior in the most relevant place.

Apply this diff:

 streaming: true
 datasets:
   - path: tatsu-lab/alpaca
     type: alpaca
     split: train

 # Optionally, enable sample packing
 streaming_multipack_buffer_size: 10000
 sample_packing: true
+# Note: In SFT, attention isolation between packed samples is automatically enabled.

17-18: Add a short migration note for removed CLI flag and renamed options

Users migrating from preprocess-time iterable mode will look for a canonical mapping in this page.

Apply this diff to add a migration subsection right after the “Configuration” header:

 ## Configuration

+### Migration notes
+
+- The `axolotl preprocess --iterable` CLI flag has been removed; streaming now flows through the training path.
+- `pretrain_multipack_buffer_size` has been renamed to `streaming_multipack_buffer_size`.
+- In SFT, attention isolation for packed samples is force-enabled when `sample_packing: true`; you do not need to set `pretrain_multipack_attn` for SFT.
+

63-66: Reference schema for default multipack buffer size and offer a starting value

The docs should avoid hard-coding the default and instead point to the config schema, while still giving “10,000” as a practical starting point (default confirmed as 10_000 in the schema).

  • File: docs/streaming.qmd
  • Lines: 63–66

Proposed diff:

-Controls the buffer size for multipack streaming (default: 10,000). This determines how
-many samples are buffered before packing. Larger buffers can improve packing efficiency
-but use more memory.

+Controls the buffer size for multipack streaming. This determines how many samples are
+buffered before packing. Larger buffers can improve packing efficiency but use more host
+memory. The default is defined in the config schema (a good starting point is 10,000).

Verification:

  • In src/axolotl/utils/schemas/config.py (line 508):
    streaming_multipack_buffer_size: int | None = 10_000
  • Examples in examples/streaming/*.yaml also set streaming_multipack_buffer_size: 10000
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between aea97a3 and 9c66115.

📒 Files selected for processing (1)
  • docs/streaming.qmd (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
🔇 Additional comments (4)
docs/streaming.qmd (4)

7-16: Nice concise intro and “when to use” guidance

Clear and actionable rationale for streaming, including the immediate-start benefit and remote/local parity.


86-86: Link Verification Complete
The file docs/multipack.qmd exists alongside docs/streaming.qmd, so the relative link [documentation](multipack.qmd) will resolve correctly. No changes required.


111-114: Update example references to clickable links

We’ve verified that the files examples/streaming/pretrain.yaml and examples/streaming/sft.yaml exist in the repo and are reachable via ../examples/streaming/... from docs/streaming.qmd. Assuming Quarto will preserve these as static assets, you can turn them into clickable links:

• File checked:
• examples/streaming/pretrain.yaml
• examples/streaming/sft.yaml

Replace the block at lines 111–114 in docs/streaming.qmd with the following diff:

@@ docs/streaming.qmd
-See the `examples/streaming/` directory for complete configuration examples:
-- `pretrain.yaml`: Pretraining with streaming dataset
-- `sft.yaml`: Supervised fine-tuning with streaming
+See the examples in the repository:
+- [examples/streaming/pretrain.yaml](../examples/streaming/pretrain.yaml): Pretraining with streaming dataset
+- [examples/streaming/sft.yaml](../examples/streaming/sft.yaml): Supervised fine-tuning with streaming

If your site build cannot resolve repo-relative links, you can revert to the plain-text list.


31-41: Validation Completed: pretraining snippet keys are correct

The PretrainingDataset schema in src/axolotl/utils/schemas/datasets.py defines:

  • type: str | None = "pretrain"
  • split: str | None = "train"
  • text_column: str | None = "text"

These match the YAML snippet’s type: pretrain, split, and text_column keys. No changes needed.

Comment thread docs/streaming.qmd
Comment thread docs/streaming.qmd

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (3)
src/axolotl/utils/collators/__init__.py (1)

1-1: Capitalize proper noun in docstring ("Mamba")

Minor clarity/style: capitalize the model name.

-"""Shared axolotl collators for multipacking, mamba, multimodal."""
+"""Shared axolotl collators for multipacking, Mamba, multimodal."""
src/axolotl/utils/schemas/validation.py (2)

63-76: Deprecation warning reads well; consider one-shot logging and pinning version in the message.

  • To avoid log spam in distributed runs, prefer LOG.warning_once here (used elsewhere in this file) or pass main_process_only=True if supported by your logger wrapper.
  • The TODO mentions 0.13.0, but the user-facing message doesn’t. If the target is firm, include it in the warning to make the migration window explicit.

Apply this minimal tweak:

-            LOG.warning(
-                "Setting `pretraining_dataset` without explicitly setting `streaming: "
-                "true` is deprecated. In a future release, streaming will not be "
+            LOG.warning_once(
+                "Setting `pretraining_dataset` without explicitly setting `streaming: "
+                "true` is deprecated and will change in Axolotl 0.13.0. Streaming will not be "
                 "automatically enabled when using pretraining_dataset. Please "
                 "explicitly set `streaming: true` in your configuration to maintain "
                 "current behavior."
             )

1112-1141: Validation rules are correct; tighten messages to reflect “non-zero” and align with streaming semantics.

  • The current condition treats 0 as allowed (good), but the error text reads like “not supported at all”. Clarify “non-zero val_set_size”.
  • The “max_steps must be set when using streaming datasets” rule is right. Consider hinting that num_train_epochs is insufficient for iterable datasets.

Apply these small message tweaks:

-            raise ValueError(
-                "val_set_size is not supported with pretraining_dataset. "
-                "Use test_datasets to specify evaluation datasets for pretraining."
-            )
+            raise ValueError(
+                "Non-zero val_set_size is not supported with pretraining_dataset. "
+                "Use test_datasets to specify evaluation datasets for pretraining."
+            )
...
-            raise ValueError(
-                "val_set_size is not supported with streaming datasets. "
-                "Use test_datasets to specify evaluation datasets when streaming is enabled."
-            )
+            raise ValueError(
+                "Non-zero val_set_size is not supported with streaming datasets. "
+                "Use test_datasets to specify evaluation datasets when streaming is enabled."
+            )
...
-            raise ValueError(
-                "max_steps must be set when using streaming datasets. "
-                "Trainer cannot infer dataset length for iterable datasets."
-            )
+            raise ValueError(
+                "max_steps must be set when using streaming datasets. "
+                "Trainer cannot infer dataset length for iterable datasets; "
+                "num_train_epochs is insufficient."
+            )

If helpful, I can add unit tests asserting:

  • val_set_size=None and =0 both pass with streaming/pretraining, but >0 raises.
  • streaming_multipack_buffer_size=0 is preserved (no unintended migration).
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 9c66115 and d9a56c4.

📒 Files selected for processing (3)
  • src/axolotl/utils/collators/__init__.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (3 hunks)
  • tests/test_streaming.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/test_streaming.py
🧰 Additional context used
🧬 Code graph analysis (1)
src/axolotl/utils/collators/__init__.py (2)
src/axolotl/utils/collators/batching.py (4)
  • BatchSamplerDataCollatorForSeq2Seq (129-155)
  • DataCollatorForSeq2Seq (12-125)
  • PretrainingBatchSamplerDataCollatorForSeq2Seq (200-227)
  • V2BatchSamplerDataCollatorForSeq2Seq (159-196)
src/axolotl/utils/collators/mamba.py (1)
  • MambaDataCollator (15-39)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: preview
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
🔇 Additional comments (2)
src/axolotl/utils/collators/__init__.py (2)

3-9: Re-export imports look good and align with the packaging intent

The curated import surface is clean and avoids circulars. No issues spotted.


11-17: No star-imports detected; add list[str] annotation to __all__

All consumer code imports the exact names you’ve listed (no import * usages were found), so defining __all__ won’t shrink any existing API surface. To satisfy strict type-checkers, annotate __all__ with a type hint:

--- a/src/axolotl/utils/collators/__init__.py
+++ b/src/axolotl/utils/collators/__init__.py
@@ lines 11-17
-__all__ = [
+__all__: list[str] = [
     "DataCollatorForSeq2Seq",
     "BatchSamplerDataCollatorForSeq2Seq",
     "V2BatchSamplerDataCollatorForSeq2Seq",
     "PretrainingBatchSamplerDataCollatorForSeq2Seq",
     "MambaDataCollator",
 ]

With this change, the public API remains intact and your code is fully type-annotated.

Comment thread src/axolotl/utils/schemas/validation.py
Comment thread docs/streaming.qmd
Comment thread src/axolotl/utils/data/sft.py Outdated

@NanoCode012 NanoCode012 left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work refactoring out preprocess_iterable.

Comment thread src/axolotl/cli/preprocess.py
Comment thread docs/streaming.qmd
Comment thread docs/streaming.qmd
Comment thread src/axolotl/utils/data/sft.py Outdated
Comment thread src/axolotl/utils/data/sft.py
Comment thread src/axolotl/utils/schemas/config.py Outdated
Comment thread src/axolotl/utils/schemas/validation.py
Comment thread docs/streaming.qmd Outdated
Comment thread docs/streaming.qmd Outdated
Comment thread docs/streaming.qmd Outdated
Comment thread docs/streaming.qmd Outdated
Comment thread src/axolotl/utils/schemas/config.py Outdated
@djsaunde djsaunde merged commit 231a67e into main Sep 2, 2025
15 checks passed
@djsaunde djsaunde deleted the streaming-v3 branch September 2, 2025 16:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants