Streaming SFT support #3101
📝 Walkthrough

Adds streaming-dataset support and docs; deprecates the preprocess-time iterable flag and routes streaming via the train path. Renames pretraining encoding/wrapping APIs to streaming equivalents, removes ConstantLengthDataset, refactors dataset loading/tokenization for streaming-first behavior, updates configs/validation, examples, and tests.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
📖 Documentation Preview: https://68b70d793f71b914b66994cb--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 4d1a47b
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/axolotl/utils/data/pretraining.py (1)
27-37: Guard against missing eos/pad tokens and clarify truncation intent.

Relying on `tokenizer.eos_token_id`/`pad_token_id` without checks can explode for tokenizers lacking them (common in base LLaMA variants pre-patch). Also, `max_length=max_tokens - 2` assumes we'll always append 2 tokens per sample; please assert this contract explicitly.
Apply:
```diff
 res = tokenizer(
     examples[text_column],
     truncation=True,
     max_length=max_tokens - 2,
     add_special_tokens=True,
 )
+# Validate special tokens (fail fast with clear guidance)
+if tokenizer.eos_token_id is None or tokenizer.pad_token_id is None:
+    raise ValueError(
+        "encode_streaming requires eos_token_id and pad_token_id. "
+        "Ensure tokenizer has these set (e.g., tokenizer.add_special_tokens or config)."
+    )
```

src/axolotl/utils/data/sft.py (1)
181-194: KeyError risk on `skip`; also propagate `text_column` if specified.

When `cfg.pretraining_dataset` is a list of dicts, `skip` is optional. Accessing `config["skip"]` will raise if absent. Use `.get` with default 0. Including `text_column` improves parity with the simple path and downstream wrappers.

```diff
-    config = cfg.pretraining_dataset[0]
-    return DictDefault(
-        {
-            "path": config["path"],
-            "name": config["name"],
-            "skip": config["skip"],
-            "split": config.get("split", "train"),
-            "data_files": config.get("data_files"),
-            "type": config.get("type", "pretrain"),
-        }
-    )
+    config = cfg.pretraining_dataset[0]
+    return DictDefault(
+        {
+            "path": config["path"],
+            "name": config.get("name"),
+            "skip": config.get("skip", 0),
+            "split": config.get("split", "train"),
+            "data_files": config.get("data_files"),
+            "type": config.get("type", "pretrain"),
+            "text_column": config.get("text_column"),
+        }
+    )
```
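The optional-key handling suggested above can be tried in isolation. The following is a minimal sketch that uses a plain `dict` in place of `DictDefault`; the helper name `normalize_pretraining_config` is hypothetical and not part of Axolotl, and only `path` is assumed to be required.

```python
def normalize_pretraining_config(config: dict) -> dict:
    """Build a dataset config while tolerating missing optional keys.

    Hypothetical helper for illustration; not an Axolotl API.
    """
    return {
        "path": config["path"],                    # required: KeyError if absent
        "name": config.get("name"),                # optional
        "skip": config.get("skip", 0),             # optional: default to no skipping
        "split": config.get("split", "train"),
        "data_files": config.get("data_files"),
        "type": config.get("type", "pretrain"),
        "text_column": config.get("text_column"),  # propagated when specified
    }


# A config that only names the dataset no longer raises on the optional keys.
minimal = normalize_pretraining_config({"path": "some/dataset"})
```

With this shape, a missing `path` still fails fast, while `skip` and `name` fall back to neutral defaults.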
🧹 Nitpick comments (48)
src/axolotl/prompt_tokenizers.py (2)
76-81: Preserve output shape invariants on early-return by including labels=[]

When returning early for empty prompts, the BatchEncoding lacks a "labels" key, whereas the non-empty path always sets it. Some downstream code assumes "labels" exists. Add an empty labels array to the early-return object to keep the contract consistent.

```diff
-    empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
+    empty = BatchEncoding(data={"input_ids": [], "attention_mask": [], "labels": []})
```
76-81: Confirm that `LOG.warning_once` actually de-duplicates emissions; otherwise this is a no-op rename

Switching to `warning_once` is good only if it really logs once. The current implementation in `axolotl.utils.logging.warning_once` appears to just proxy to `warning` with no caching/dedup, so this won't reduce log spam. Either (a) implement the dedup logic in the logger, or (b) revert to `LOG.warning` to avoid misleading semantics.

If you want a minimal implementation in src/axolotl/utils/logging.py:

```python
# Example patch: implement once-only by message text
import logging


class AxoLogger(logging.Logger):
    def __init__(self, name, level=logging.NOTSET):
        super().__init__(name, level)
        # Messages already emitted through warning_once
        self._warned_messages = set()

    def warning_once(self, msg, *args, **kwargs):
        key = str(msg)
        if key in self._warned_messages:
            return  # already logged once; suppress repeats
        self._warned_messages.add(key)
        self.warning(msg, *args, **kwargs)
```

Then ensure get_logger returns AxoLogger instances. I can provide a full patch if helpful.
src/axolotl/datasets.py (5)
4-7: Nit: fix grammar in module docstring ("Let's" vs "Lets")

Tiny polish for user-facing docs.

```diff
-We want this to be a wrapper for an existing dataset that we have loaded. Lets use the
+We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
 concept of middlewares to wrap each dataset. We'll use the collators later on to pad the datasets.
```
28-35: Type tightening: keep_in_memory should be bool, not Optional[bool]

`datasets.map` expects a bool for `keep_in_memory`. Annotating as Optional increases the chance of passing None through. Make the type bool and default False.

```diff
-        process_count: int | None = None,
-        keep_in_memory: bool | None = False,
+        process_count: int | None = None,
+        keep_in_memory: bool = False,
```
45-51: Normalize remove_columns input to a list

For consistency with the IterableDataset branch, and to avoid any surprises with dict_keys views, cast features to list before passing to `remove_columns`.

```diff
-        features = dataset.features.keys()
+        features = list(dataset.features.keys())
```
62-69: Optional: make batch size configurable and consistent across streaming/non-streaming paths

Non-iterable path sets `batch_size=1_000` when `supports_batched` is True; the IterableDataset path leaves it unset (uses library default). Consider threading a common `batch_size` from the tokenizer or kwargs to keep behavior consistent and tunable.

Example minimal change (mirrors the non-streaming 1_000 default):

```diff
         return dataset.map(
             self.prompt_tokenizer.tokenize_prompt,
             num_proc=self.process_count,
             remove_columns=features,
             keep_in_memory=self.keep_in_memory,
             desc="Tokenizing Prompts",
-            **map_kwargs,
+            **({"batch_size": 1_000} | map_kwargs),
         )
```

And in `wrap_dataset_for_tokenized_prompt` (IterableDataset):

```diff
-    if prompt_tokenizer.supports_batched:
-        map_kwargs["batched"] = True
+    if prompt_tokenizer.supports_batched:
+        map_kwargs["batched"] = True
+        map_kwargs["batch_size"] = 1_000
```

If you prefer, expose `batch_size` via kwargs or a field on PromptTokenizingStrategy.
77-86: IterableDataset.features may be absent/partial; guard remove_columns computation

Some streaming sources don't define features eagerly. If `features` is None, `list(dataset.features.keys())` will fail. Add a fallback to infer columns from a sample row or skip `remove_columns` when unknown.

```diff
-    features = list(dataset.features.keys())
+    features = list(dataset.features.keys()) if dataset.features is not None else None
     return dataset.map(
         prompt_tokenizer.tokenize_prompt,
-        remove_columns=features,
+        **({} if features is None else {"remove_columns": features}),
         **map_kwargs,
     )
```

If you want me to add a safe "peek 1 row" helper for IterableDataset to infer columns, I can draft it.
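One possible shape for the "peek 1 row" fallback mentioned above is sketched here. It is an assumption-laden illustration rather than the helper the reviewer had in mind: `peek_columns` and `FakeStream` are hypothetical names, and the sketch works on any iterable of dict rows with an optional `features` attribute instead of a real `datasets.IterableDataset`.

```python
from itertools import chain, islice


def peek_columns(dataset):
    """Return (column_names, rows) for an iterable of dict rows.

    Uses `dataset.features` when it is set; otherwise reads one row to infer
    the columns and hands back an iterator that still yields that row.
    Hypothetical helper; not part of Axolotl or the `datasets` library.
    """
    features = getattr(dataset, "features", None)
    if features is not None:
        return list(features.keys()), dataset
    iterator = iter(dataset)
    first = list(islice(iterator, 1))
    if not first:
        return None, iter(())  # empty stream: columns stay unknown
    return list(first[0].keys()), chain(first, iterator)


class FakeStream:
    """Stand-in for a streaming source that defines no schema up front."""

    features = None

    def __iter__(self):
        yield {"text": "a", "meta": 1}
        yield {"text": "b", "meta": 2}


columns, stream = peek_columns(FakeStream())
rows = list(stream)
```

In the fallback path this returns a plain iterator, which has no `.map`. With a real streaming dataset it would likely be preferable to peek through a separate iterator and keep using the original dataset object, at the cost of re-reading the first row.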
tests/e2e/integrations/test_kd.py (1)
28-28: Switching to qwen3 chat template looks consistent; consider aligning test naming.

The change to `chat_template: "qwen3"` matches the Qwen3 base model and dataset artifacts used in this test. To reduce confusion, consider renaming `test_llama_kd` (Line 82) to reflect Qwen3, since the model/template stack is no longer LLaMA-based. Non-blocking.

src/axolotl/utils/data/utils.py (1)
193-207: Make streaming detection explicit and reduce false positives in warnings.

The current guard infers streaming via the presence/absence of `column_names`, and warns about missing `input_ids` before checking for a streaming dataset. This can (a) warn in legitimate pre-tokenization/streaming paths and (b) miss the simpler/clearer type check via `IterableDataset`.

Suggest handling streaming first and softening the message for non-`input_ids` datasets (e.g., reward modeling, pre-tokenization). This preserves behavior while improving clarity.

Apply this focused refactor:

```diff
-    if (
-        hasattr(dataset, "column_names")
-        and dataset.column_names
-        and "input_ids" not in dataset.column_names
-    ):
-        LOG.warning(
-            "Dataset does not contain 'input_ids' column. Skip drop long seq. This is "
-            "expected for reward modeling."
-        )
-        return dataset
-    elif not hasattr(dataset, "column_names") or dataset.column_names is None:
-        LOG.info(
-            "Dataset is streaming (IterableDataset), skipping long sequence handling"
-        )
-        return dataset
+    # Prefer explicit streaming detection first
+    if isinstance(dataset, IterableDataset) and not getattr(dataset, "column_names", None):
+        LOG.info(
+            "Dataset is streaming (IterableDataset) without a column schema; skipping long-sequence handling."
+        )
+        return dataset
+    # If we have a schema but no input_ids, don't attempt long-seq handling
+    if getattr(dataset, "column_names", None) and "input_ids" not in dataset.column_names:
+        LOG.warning(
+            "Dataset does not contain 'input_ids'. Skipping long-sequence handling (common for reward modeling or pre-tokenization stages)."
+        )
+        return dataset
```

src/axolotl/cli/args.py (1)
17-23: Use a non-optional bool for a deprecated flag.

`iterable` is now deprecated and has a concrete default. Making it `bool` (not `Optional[bool]`) avoids a tri-state and simplifies downstream checks.

Apply this small type tweak:

```diff
-    iterable: Optional[bool] = field(
+    iterable: bool = field(
         default=False,
         metadata={
             "help": (
                 "[DEPRECATED] No longer supported. For streaming datasets, use "
                 "'axolotl train' and set 'streaming: true' in your YAML config, or "
                 "pass --streaming instead in the CLI."
             )
         },
     )
```

src/axolotl/cli/preprocess.py (2)
38-46: Return a non-zero exit code when an unsupported flag is used.

Currently we log an error and return, which yields exit code 0 from the CLI. This can mask misuse in scripts/CI. Prefer raising `SystemExit(2)` after logging so callers can detect the failure.

Apply this minimal change:

```diff
     if cli_args.iterable:
         LOG.error(
             "The --iterable CLI argument for 'axolotl preprocess' is no longer "
             "supported. For training, set 'streaming: true' in your YAML config or "
             "pass '--streaming' in your 'axolotl train' command for on-the-fly "
             "preprocessing."
         )
-        return
+        raise SystemExit(2)
```
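The effect of the suggested change on the process exit code can be seen in a small standalone sketch. The function name `reject_iterable_flag` is hypothetical and stands in for the check inside the preprocess CLI; it is not Axolotl's actual code.

```python
import logging

LOG = logging.getLogger(__name__)


def reject_iterable_flag(iterable: bool) -> None:
    """Log an error and exit non-zero when the unsupported flag is set.

    Hypothetical stand-in for the check in the preprocess CLI.
    """
    if iterable:
        LOG.error("The --iterable argument is no longer supported.")
        # SystemExit carries the exit status the interpreter will report.
        raise SystemExit(2)


# Simulate what a calling script or CI job would observe.
try:
    reject_iterable_flag(True)
    exit_code = 0
except SystemExit as exc:
    exit_code = exc.code
```

A bare `return` in the same position would leave `exit_code` at 0, which is the masking behavior the comment describes.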
50-52: Minor consistency nit: unify CLI name formatting.

Elsewhere we tend to wrap CLI names with backticks in messages. Consider switching `"Run the 'axolotl train' CLI..."` to backticks for consistency (or standardize repository-wide).

src/axolotl/utils/schemas/validation.py (2)
343-363: Great deprecation shim; use warning_once to avoid log spam.

The migration from `pretrain_multipack_buffer_size` to `streaming_multipack_buffer_size` is handled well. To reduce repeated warnings across multiple validations/parses, prefer `LOG.warning_once` (already used elsewhere in this module).

Apply this small change:

```diff
-            LOG.warning(
+            LOG.warning_once(
                 "`pretrain_multipack_buffer_size` is deprecated. Please use `streaming_multipack_buffer_size` instead."
             )
```
1098-1126: Validation rules are solid; make val_set_size check robust to 0/0.0.

The truthiness check handles most cases, but being explicit about zero avoids edge cases if a config loader ever passes `"0"` as a string. Optional nit.

Suggested tweak:

```diff
-        if data.get("pretraining_dataset") and data.get("val_set_size"):
+        val = data.get("val_set_size")
+        if data.get("pretraining_dataset") and val not in (None, 0, 0.0, "0", "0.0"):
             raise ValueError(
                 "val_set_size is not supported with pretraining_dataset. "
                 "Use test_datasets to specify evaluation datasets for pretraining."
             )
-        if data.get("streaming") and data.get("val_set_size"):
+        val = data.get("val_set_size")
+        if data.get("streaming") and val not in (None, 0, 0.0, "0", "0.0"):
             raise ValueError(
                 "val_set_size is not supported with streaming datasets. "
                 "Use test_datasets to specify evaluation datasets when streaming is enabled."
             )
```

And keeping the `max_steps` requirement for streaming is the right call. LGTM there.

src/axolotl/utils/data/pretraining.py (8)
37-42: Non-concatenate path likely needs label masking for PAD (and possibly BOS) consistency.

When `concatenate=False`, labels mirror `input_ids` (no -100 masking for pad). This diverges from the concatenated path and can hurt loss. Consider masking padding (and optionally BOS depending on training objective).

Example:

```diff
 if not concatenate:
-    return {
-        "input_ids": [seq.tolist() for seq in input_ids],
-        "labels": [seq.tolist() for seq in targets],
-        "attention_mask": [seq.tolist() for seq in attention_mask],
-    }
+    masked_labels = []
+    for ids, mask in zip(input_ids, attention_mask):
+        # ignore PAD tokens in loss
+        lbl = ids.clone()
+        lbl[mask == 0] = -100
+        masked_labels.append(lbl)
+    return {
+        "input_ids": [seq.tolist() for seq in input_ids],
+        "labels": [seq.tolist() for seq in masked_labels],
+        "attention_mask": [seq.tolist() for seq in attention_mask],
+    }
```
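The masking rule in the suggestion above can be shown without tensors. This is a minimal sketch on plain Python lists, assuming the usual convention that label `-100` is ignored by the loss (it is the default `ignore_index` of PyTorch's cross-entropy); `mask_padding_labels` is a hypothetical name.

```python
IGNORE_INDEX = -100  # label value skipped by the loss by convention


def mask_padding_labels(input_ids, attention_mask):
    """Copy input_ids to labels, replacing padded positions with IGNORE_INDEX.

    Hypothetical list-based illustration of the tensor code in the review.
    """
    return [
        [tok if keep else IGNORE_INDEX for tok, keep in zip(ids, mask)]
        for ids, mask in zip(input_ids, attention_mask)
    ]


# One sequence of three real tokens followed by two pad tokens (id 0).
labels = mask_padding_labels([[5, 6, 7, 0, 0]], [[1, 1, 1, 0, 0]])
```

Masking by `attention_mask` rather than by comparing against `pad_token_id` keeps real tokens intact when the pad id coincides with another special token such as EOS.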
131-165: Pad leftovers without a while loop.

The while loop always executes once (padding to exact `max_tokens` in a single cat). This can be simplified and made faster.

```diff
-    if buffer_input_ids.numel() > 0:  # for any leftover tokens
-        while buffer_input_ids.numel() < max_tokens:  # make all sequences equal in size
-            buffer_input_ids = torch.cat( ... )
-            buffer_labels = torch.cat( ... )
-            buffer_attention_mask = torch.cat( ... )
+    if buffer_input_ids.numel() > 0:  # for any leftover tokens
+        pad_len = max_tokens - buffer_input_ids.numel()
+        if pad_len > 0:
+            buffer_input_ids = torch.cat(
+                (buffer_input_ids, torch.full((pad_len,), tokenizer.pad_token_id, dtype=torch.long)), dim=0
+            )
+            buffer_labels = torch.cat(
+                (buffer_labels, torch.full((pad_len,), -100, dtype=torch.long)), dim=0
+            )
+            buffer_attention_mask = torch.cat(
+                (buffer_attention_mask, torch.zeros((pad_len,), dtype=torch.long)), dim=0
+            )
```
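The single-step padding in the suggestion above amounts to computing the shortfall once and appending that many fill values to each buffer. The following sketch illustrates this with plain lists; `pad_to_length` is a hypothetical helper, and the fill values (pad id, `-100`, `0`) follow the diff.

```python
def pad_to_length(input_ids, labels, attention_mask, max_tokens, pad_token_id):
    """Pad three parallel lists to max_tokens in one step (no loop).

    Hypothetical list-based illustration of the tensor code in the review.
    """
    pad_len = max_tokens - len(input_ids)
    if pad_len <= 0:
        return input_ids, labels, attention_mask  # already full length
    return (
        input_ids + [pad_token_id] * pad_len,   # pad tokens
        labels + [-100] * pad_len,              # ignored by the loss
        attention_mask + [0] * pad_len,         # not attended to
    )


ids, lbls, mask = pad_to_length(
    [1, 2, 3], [1, 2, 3], [1, 1, 1], max_tokens=5, pad_token_id=0
)
```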
175-176: Improve debug log context.

`LOG.debug(len(...))` without a message makes tracing logs painful.

```diff
-LOG.debug(len(ret["input_ids"]))
+LOG.debug("encode_streaming produced %d sequences of length %d", len(ret["input_ids"]), max_tokens)
```
208-210: Avoid mutating cfg.micro_batch_size as a side effect.

Overwriting cfg at runtime can surprise downstream components, logs, or callbacks that rely on the configured value. Since you already captured the original `micro_batch_size` in the `functools.partial`, prefer using a local override for the data loader only.

```diff
-# Set this to 1 so downstream data_loader doesn't try to increase the batch size again
-cfg.micro_batch_size = 1
+# Avoid mutating cfg; keep the actual data_loader batch size at 1 via the collator/sampler path
+# (If a consumer depends on cfg.micro_batch_size later, consider introducing cfg._effective_loader_batch_size)
```

If mutation is required for compatibility, at least stash-and-restore around dataset construction.
220-224: Shuffle buffer sized to streaming_multipack_buffer_size: OK, but document memory trade-offs.

Using a large shuffle buffer here can spike RAM. Consider logging the effective buffer size on startup for operator awareness.

```diff
-    if cfg.shuffle_merged_datasets:
+    if cfg.shuffle_merged_datasets:
+        LOG.info("Shuffling streaming dataset with buffer_size=%d and seed=%s",
+                 cfg.streaming_multipack_buffer_size, cfg.seed)
         dataset = dataset.shuffle(
             seed=cfg.seed, buffer_size=cfg.streaming_multipack_buffer_size
         )
```
241-243: Batch size for map equals streaming_multipack_buffer_size; sanity-check defaults.

This couples mapping batch size to the packing buffer. Reasonable, but consider capping to a safe upper bound (e.g., `min(buffer_size, 8192)`) or exposing a separate `encode_batch_size` for fine control.
No code diff necessary if you prefer current behavior; a config knob would suffice.
282-285: Duplicate deletion of num_truncated_tokens.

The key removal appears twice.

```diff
-    if "num_truncated_tokens" in features:
-        del features["num_truncated_tokens"]
     if "num_truncated_tokens" in features:
         del features["num_truncated_tokens"]
```
277-297: Return python lists for map-compatibility and parity with encode_streaming.

`encode_streaming` returns lists of ints; `encode_packed_streaming` returns lists of Tensors. HF streaming map often tolerates tensors, but returning native lists avoids surprises and aligns keys.

```diff
-            chunked_data[feature].append(collated_features[feature].squeeze(0))
+            chunked_data[feature].append(
+                collated_features[feature].squeeze(0).tolist()
+            )
```

Also consider sorting keys to ensure deterministic column order.
tests/test_data.py (3)
14-17: Rename test class/docstring for clarity.

Now that the subject is `encode_streaming`, the class/docstring name "EncodePretraining" is misleading.

```diff
-class TestEncodePretraining(unittest.TestCase):
-    """
-    test class for encode pretraining and md5 helper
-    """
+class TestEncodeStreaming(unittest.TestCase):
+    """Tests for encode_streaming and the md5 helper."""
```
32-59: Add a test for concatenate=False and enforce no double-EOS.

The current test only covers the concatenated path. Add a case for `concatenate=False` and one that checks we don't produce double EOS when the tokenizer already adds EOS.
Happy to draft the additional tests if you share your preferred tokenizer fixture constraints.
21-29: Offline tokenizer setup is good; consider pinning a tiny tokenizer for speed.

huggyllama/llama-7b is large; even offline, tokenizer files are heavier than needed. A tiny tokenizer (e.g., a small BPE from a toy model) would speed up tests.
If swapping is non-trivial, keep as is.
examples/streaming/README.md (4)
3-4: Tighten wording.

Minor grammar/style polish.

```diff
-This directory contains example configurations for using Axolotl's streaming dataset
-functionality, which enables memory-efficient training with large datasets.
+This directory contains example configurations for Axolotl's streaming datasets,
+which enable memory-efficient training on large datasets.
```
10-16: Clarify model/dataset phrasing and bullets.

```diff
-Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset
-with SmolLM2-135M.
+Demonstrates a streaming configuration for pretraining on FineWeb-Edu
+using the SmolLM2-135M base model.

-- Uses `pretraining_dataset` configuration for automatic streaming
-- Multipack attention control to prevent cross-attention between packed sequences
-- Buffer size configuration for memory management
+- Automatic streaming via `pretraining_dataset`
+- Multipack attention to prevent cross-attention between packed sequences
+- Tune buffer size for memory/throughput trade-offs
```
21-24: Clarify SFT bullets.

```diff
-- Explicit `streaming: true` flag for SFT datasets
-- Evaluation dataset handling with streaming
-- Memory-efficient training on instruction datasets
+- Enable with `streaming: true` for SFT datasets
+- Evaluation datasets are loaded normally (not streamed)
+- Memory-efficient SFT on instruction datasets
```
60-62: Add a tip about local HF cache.

Mention HF_DATASETS_CACHE to minimize network bottlenecks.

```diff
-- Download small / frequently-used datasets locally for better performance
+- Cache datasets locally (HF_DATASETS_CACHE) for better performance
```

docs/streaming.qmd (4)
30-43: Pretraining auto-streaming section reads well. Add a one-liner about "no preprocess needed."

Matches the PR narrative and examples; adding an explicit note will reduce questions.

````diff
 For pretraining tasks, streaming is automatically enabled when using `pretraining_dataset`:

 ```yaml
@@
 sample_packing: true
 ```
+
+You don't need to run `axolotl preprocess` when using streaming pretraining.
````

83-86: Clarify multipack attention semantics between SFT and pretraining.

Make explicit that SFT isolates attention automatically; pretraining uses `pretrain_multipack_attn`.

```diff
-# For SFT: attention is automatically isolated between packed samples
-# For pretraining: control with pretrain_multipack_attn
+# For SFT: attention is automatically isolated between packed samples (always enabled)
+# For pretraining: enable isolation via pretrain_multipack_attn
```
106-110: Evaluation not streamed; good. Add a pointer to why (determinism).

```diff
-Evaluation datasets are not streamed to ensure consistent evaluation metrics. They're
-loaded normally even when training uses streaming.
+Evaluation datasets are not streamed to ensure deterministic, comparable metrics.
+They're loaded normally even when training uses streaming.
```
61-73: Consider warning callouts for memory knobs.

This section is great; a short caution callout would stand out in rendered docs.

Add:

```
::: {.callout-warning}
Large `streaming_multipack_buffer_size` and dataset shuffling can significantly increase RAM usage. Start smaller and scale up based on headroom.
:::
```

tests/test_packed_pretraining.py (1)
79-86: Future-proof the test against cfg.micro_batch_size mutation inside wrap_streaming_dataset.

This test correctly captures `original_bsz` before calling `wrap_streaming_dataset`, relying on the fact that `wrap_streaming_dataset` sets `cfg.micro_batch_size = 1` only after capturing the prior value into encode's `batch_size`. To make the intent explicit and guard against future refactors that might change that order, add a quick assertion after the call.

Apply this diff:

```diff
     original_bsz = cfg.micro_batch_size
     train_dataset = wrap_streaming_dataset(
         dataset,
         tokenizer_huggyllama,
         cfg,
         ds_wrapper_partial,
     )
+    # wrap_streaming_dataset flattens packed samples into a single item and sets
+    # micro_batch_size to 1 for the downstream DataLoader. Keep this invariant visible.
+    assert cfg.micro_batch_size == 1
```

examples/streaming/sft.yaml (4)
10-13: Avoid underscores in numeric literals in YAML for maximum loader compatibility.

Some YAML parsers (and older loader stacks) don't accept underscore-separated numerals. Switching 10_000 → 10000 avoids parsing ambiguity.

```diff
-streaming_multipack_buffer_size: 10_000
+streaming_multipack_buffer_size: 10000
```
24-26: Clarify packing expectations with micro_batch_size=1.

With `sample_packing: true` and `micro_batch_size: 1`, packing still works (it packs multiple short samples into a single sequence), but users sometimes expect "flattened" mega-batches from micro_batch_size>1. A brief comment in this example could reduce confusion.

```diff
 sample_packing: true
 flash_attention: true
+# micro_batch_size controls per-step item count; packing still groups multiple short samples
+# into each sequence even when this is 1.
 micro_batch_size: 1
```

Also applies to: 19-22
45-50: Explicitly set WandB fields to null or remove them to avoid accidental init via environment.

Empty keys are interpreted as null by most loaders, but being explicit (or omitting them) prevents accidental WandB initialization when environment variables are present.

```diff
-# Weights & Biases (optional)
-wandb_project:
-wandb_entity:
-wandb_watch:
-wandb_name:
-wandb_log_model:
+# Weights & Biases (optional): leave commented out or set explicitly to null when needed
+# wandb_project: null
+# wandb_entity: null
+# wandb_watch: null
+# wandb_name: null
+# wandb_log_model: null
```
19-22: Hardware-sensitive knobs in an example config.

`flash_attention: true` and `tf32: true` are great defaults on recent NVIDIA GPUs, but they can surprise users on unsupported hardware. A one-line comment helps set expectations.

```diff
 flash_attention: true
+## Requires compatible GPU/software stack; set to false if you see kernel errors.
 ...
 tf32: true
+## Effective on Ampere+ GPUs; harmless on others but can be disabled if desired.
```

Also applies to: 35-37
tests/e2e/test_streaming.py (2)
31-33: pretrain_multipack_attn is unused in SFT configs.

In the streaming SFT path (datasets provided; no `pretraining_dataset`), `wrap_streaming_dataset` forces `multipack_attn=True` and ignores `pretrain_multipack_attn`. Dropping it reduces confusion.

```diff
-            "pretrain_multipack_attn": sample_packing,
             "streaming_multipack_buffer_size": 10000,
             "dataset_processes": 1,
```
51-53: Consolidate optimizer configuration across tests

Given that `"adamw_torch_fused"` is used in dozens of test files (over 90 occurrences found), updating each instance individually is error-prone and hard to maintain. Instead, centralize the optimizer choice and introduce a fallback for environments where the fused variant isn't available. For example:

Define a helper in your test utilities (e.g. `tests/utils/optimizers.py`):

```python
import torch


def get_default_optimizer_name():
    # Conservative check: only pick the fused variant when CUDA is available,
    # and fall back to the plain optimizer on CPU-only runners.
    if torch.cuda.is_available():
        return "adamw_torch_fused"
    return "adamw_torch"
```

In your tests or builders, replace literal strings with:

```python
from tests.utils.optimizers import get_default_optimizer_name
...
"optimizer": get_default_optimizer_name(),
```

This approach:

- Ensures CI portability without touching each test.
- Keeps performance benefits on CUDA-enabled runners.
- Simplifies future changes to optimizer defaults.

Locations requiring this refactor (non-exhaustive):

- tests/test_packed_dataset.py:66
- tests/e2e/test_profiler.py:39
- tests/e2e/test_streaming.py:51
- tests/core/test_builders.py:39
- …and all other instances of `"optimizer": "adamw_torch_fused"` under `tests/`

src/axolotl/utils/data/__init__.py (2)
4-6: Provide backward-compat aliases for renamed APIs to ease downstream migration.

This rename is reasonable, but external users may still import `encode_pretraining`/`wrap_pretraining_dataset`. Keep soft aliases with a deprecation warning for a release or two.

```diff
-from axolotl.utils.data.pretraining import (
-    encode_streaming,
-    wrap_streaming_dataset,
-)
+import warnings
+from axolotl.utils.data.pretraining import encode_streaming, wrap_streaming_dataset
+
+# Backward-compat (deprecated): keep old names alive with warnings
+def encode_pretraining(*args, **kwargs):
+    warnings.warn(
+        "encode_pretraining is deprecated; use encode_streaming instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    return encode_streaming(*args, **kwargs)
+
+def wrap_pretraining_dataset(*args, **kwargs):
+    warnings.warn(
+        "wrap_pretraining_dataset is deprecated; use wrap_streaming_dataset instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    return wrap_streaming_dataset(*args, **kwargs)
```
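The two near-identical shims above could also be produced by one small factory. This is a hedged sketch of that design choice, not something the PR does: `deprecated_alias` is a hypothetical helper, and `encode_streaming` here is a toy stand-in for the real function.

```python
import functools
import warnings


def deprecated_alias(new_func, old_name):
    """Return a wrapper that warns about old_name, then calls new_func.

    Hypothetical helper; not part of Axolotl.
    """

    @functools.wraps(new_func)
    def wrapper(*args, **kwargs):
        warnings.warn(
            f"{old_name} is deprecated; use {new_func.__name__} instead.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller
        )
        return new_func(*args, **kwargs)

    wrapper.__name__ = old_name  # keep the old name visible in tracebacks
    return wrapper


def encode_streaming(text):
    """Toy stand-in for the real encode_streaming."""
    return text.split()


encode_pretraining = deprecated_alias(encode_streaming, "encode_pretraining")

# Record warnings so the deprecation can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = encode_pretraining("a b")
```

`DeprecationWarning` is hidden by default outside of `__main__` and test runners, so downstream users may only see it when running with `-W default` or under pytest.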
14-21: Add deprecated aliases to `__all__` for backward compatibility

I ran the following search to ensure there are no lingering references to the old function names outside of the shim in `__init__.py`:

```shell
rg -nP '\b(wrap_pretraining_dataset|encode_pretraining)\b' -g '!src/axolotl/utils/data/__init__.py'
```

No matches were found, confirming it's safe to introduce these deprecated exports without breaking internal imports.

Apply this optional refactor to include the deprecated names:

```diff
 __all__ = [
     "encode_streaming",
     "wrap_streaming_dataset",
+    # Deprecated exports (remove in a future major)
+    "encode_pretraining",
+    "wrap_pretraining_dataset",
     "prepare_preference_datasets",
     "get_dataset_wrapper",
     "prepare_datasets",
     "md5",
 ]
```

src/axolotl/utils/schemas/config.py (2)
492-499: Mark pretrain_multipack_buffer_size as deprecated in the schema and auto-migrate to streaming_multipack_buffer_size.

The comment notes deprecation, but adding formal deprecation metadata plus an auto-migration guard reduces user friction and improves generated schemas.

```diff
-    # Deprecated: Use streaming_multipack_buffer_size instead
-    pretrain_multipack_buffer_size: int | None = None
+    # Deprecated: Use streaming_multipack_buffer_size instead
+    pretrain_multipack_buffer_size: int | None = Field(
+        default=None,
+        deprecated="Use streaming_multipack_buffer_size instead",
+        json_schema_extra={
+            "description": "Deprecated: Use streaming_multipack_buffer_size instead"
+        },
+    )
```

Additionally, add a before-validator to auto-migrate the value when provided:

```diff
@@ class AxolotlInputConfig(
@@
 ):
@@
     model_config = {"populate_by_name": True}
@@
+    @model_validator(mode="before")
+    @classmethod
+    def migrate_pretrain_buffer_size(cls, data):
+        # Migrate deprecated field when the new field is unset
+        if (
+            data.get("pretrain_multipack_buffer_size") is not None
+            and data.get("streaming_multipack_buffer_size") is None
+        ):
+            data["streaming_multipack_buffer_size"] = data[
+                "pretrain_multipack_buffer_size"
+            ]
+            LOG.warning(
+                "pretrain_multipack_buffer_size is deprecated; "
+                "using its value for streaming_multipack_buffer_size."
+            )
+        return data
```
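The migration rule in the validator above is independent of pydantic and can be checked on plain dicts. The sketch below assumes the same precedence (the new key wins when both are set); `migrate_buffer_size` is a hypothetical function name, and unlike the validator it returns a copy instead of mutating its input.

```python
import logging

LOG = logging.getLogger(__name__)

OLD_KEY = "pretrain_multipack_buffer_size"
NEW_KEY = "streaming_multipack_buffer_size"


def migrate_buffer_size(data: dict) -> dict:
    """Copy the deprecated key's value to the new key when the new key is unset.

    Hypothetical standalone version of the before-validator in the review.
    """
    migrated = dict(data)  # avoid mutating the caller's dict
    if migrated.get(OLD_KEY) is not None and migrated.get(NEW_KEY) is None:
        migrated[NEW_KEY] = migrated[OLD_KEY]
        LOG.warning("%s is deprecated; using its value for %s.", OLD_KEY, NEW_KEY)
    return migrated


old_only = migrate_buffer_size({OLD_KEY: 5000})
both_set = migrate_buffer_size({OLD_KEY: 5000, NEW_KEY: 20000})
```

One caveat: if the new field has a non-`None` schema default, a before-validator sees the raw input (where the key is absent), so the migration still applies, whereas an after-validator would always find the default already filled in.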
507-509: Guardrails for streaming_multipack_buffer_size.

Consider validating that the buffer is positive and not absurdly large relative to available RAM to prevent OOMs during `dataset.map` on big corpora.

Option A (schema-level constraint): add `ge=1` and a brief description.

```diff
-    streaming_multipack_buffer_size: int | None = 10_000
+    streaming_multipack_buffer_size: int | None = Field(
+        default=10_000,
+        ge=1,
+        json_schema_extra={"description": "Batch size for streaming map/packing buffer"}
+    )
```

Option B (validator): cap or warn when > 1e6.
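Option B is only described in a single line above, so the following is a speculative sketch of what such a check might look like as a plain function. The name `check_buffer_size`, the choice to warn rather than cap, and the exact threshold constant are all assumptions for illustration.

```python
import logging

LOG = logging.getLogger(__name__)

WARN_THRESHOLD = 1_000_000  # assumed cutoff, taken from the "> 1e6" in the review


def check_buffer_size(value):
    """Reject non-positive buffer sizes and warn on very large ones.

    Hypothetical validator body; None is passed through as "unset".
    """
    if value is None:
        return None
    if value < 1:
        raise ValueError("streaming_multipack_buffer_size must be >= 1")
    if value > WARN_THRESHOLD:
        LOG.warning(
            "streaming_multipack_buffer_size=%d is very large and may exhaust RAM.",
            value,
        )
    return value


accepted = check_buffer_size(10_000)

try:
    check_buffer_size(0)
    rejected = False
except ValueError:
    rejected = True
```

Warning instead of capping leaves the decision with the operator, who may have enough memory headroom for a larger buffer.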
examples/streaming/pretrain.yaml (2)
15-18: Consider setting a deterministic seed.

Adding `seed` makes shuffling and sampling reproducible across runs.

```diff
 # Training configuration
+seed: 42
 max_steps: 1000
 output_dir: ./outputs/smollm2-135m-pretrain-streaming
```
40-45: Optional: add eval to monitor training.

If you want on-the-fly validation in this example, add a small `test_datasets` block and set `eval_sample_packing` accordingly.

tests/test_streaming.py (1)
34-37: Make log assertion resilient to message ordering.

Indexing `cm.output[0]` is brittle. Assert against any captured warning.

```diff
-        with self.assertLogs("axolotl.utils.schemas.validation", level="WARNING") as cm:
-            validated_cfg = validate_config(cfg_old)
-        self.assertIn("pretrain_multipack_buffer_size` is deprecated", cm.output[0])
+        with self.assertLogs("axolotl.utils.schemas.validation", level="WARNING") as cm:
+            validated_cfg = validate_config(cfg_old)
+        self.assertTrue(
+            any("pretrain_multipack_buffer_size` is deprecated" in msg for msg in cm.output),
+            "Expected deprecation warning for pretrain_multipack_buffer_size"
+        )
```

src/axolotl/utils/data/sft.py (1)
208-211: Nit: variable naming.

`pretraining_config` now covers SFT-with-streaming too. Consider renaming to `dataset_config` for clarity (no functional change).
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (23)
- _quarto.yml (1 hunks)
- docs/streaming.qmd (1 hunks)
- examples/streaming/README.md (1 hunks)
- examples/streaming/pretrain.yaml (1 hunks)
- examples/streaming/sft.yaml (1 hunks)
- src/axolotl/cli/args.py (1 hunks)
- src/axolotl/cli/preprocess.py (1 hunks)
- src/axolotl/common/datasets.py (0 hunks)
- src/axolotl/datasets.py (1 hunks)
- src/axolotl/prompt_tokenizers.py (1 hunks)
- src/axolotl/utils/data/__init__.py (2 hunks)
- src/axolotl/utils/data/pretraining.py (3 hunks)
- src/axolotl/utils/data/sft.py (8 hunks)
- src/axolotl/utils/data/shared.py (0 hunks)
- src/axolotl/utils/data/utils.py (1 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- src/axolotl/utils/schemas/validation.py (2 hunks)
- tests/e2e/integrations/test_kd.py (1 hunks)
- tests/e2e/test_streaming.py (1 hunks)
- tests/test_data.py (2 hunks)
- tests/test_packed_dataset.py (0 hunks)
- tests/test_packed_pretraining.py (2 hunks)
- tests/test_streaming.py (1 hunks)
💤 Files with no reviewable changes (3)
- src/axolotl/common/datasets.py
- src/axolotl/utils/data/shared.py
- tests/test_packed_dataset.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-08-08T07:22:40.131Z
Learnt from: winglian
PR: axolotl-ai-cloud/axolotl#3038
File: examples/slurm/axolotl.slurm:16-16
Timestamp: 2025-08-08T07:22:40.131Z
Learning: In Axolotl (PR #3038), the preprocess codepath sets AXOLOTL_IS_PREPROCESS internally, so external scripts (e.g., examples/slurm/axolotl.slurm) need not export it for the early-return in src/axolotl/utils/data/sft.py to trigger.
Applied to files:
src/axolotl/cli/preprocess.py
🧬 Code graph analysis (8)
tests/test_streaming.py (4)
- src/axolotl/utils/dict.py (1): DictDefault (6-38)
- src/axolotl/utils/data/sft.py (2): _prepare_streaming_dataset (125-176), prepare_datasets (48-65)
- src/axolotl/utils/config/__init__.py (1): validate_config (259-303)
- src/axolotl/utils/data/pretraining.py (1): wrap_streaming_dataset (179-244)
src/axolotl/prompt_tokenizers.py (1)
- src/axolotl/utils/logging.py (1): warning_once (31-39)
tests/test_data.py (3)
- src/axolotl/utils/data/pretraining.py (1): encode_streaming (20-176)
- src/axolotl/utils/data/utils.py (1): md5 (73-78)
- tests/test_datasets.py (1): tokenizer (31-35)
tests/e2e/test_streaming.py (3)
- src/axolotl/utils/config/__init__.py (2): normalize_config (97-237), validate_config (259-303)
- src/axolotl/utils/dict.py (1): DictDefault (6-38)
- tests/e2e/utils.py (2): check_model_output_exists (162-183), check_tensorboard (140-159)
tests/test_packed_pretraining.py (2)
- src/axolotl/utils/data/wrappers.py (1): get_dataset_wrapper (57-131)
- src/axolotl/utils/data/pretraining.py (1): wrap_streaming_dataset (179-244)
src/axolotl/utils/data/__init__.py (1)
- src/axolotl/utils/data/pretraining.py (2): encode_streaming (20-176), wrap_streaming_dataset (179-244)
src/axolotl/utils/data/pretraining.py (1)
- src/axolotl/utils/collators/batching.py (1): PretrainingBatchSamplerDataCollatorForSeq2Seq (200-227)
src/axolotl/utils/data/sft.py (3)
- src/axolotl/utils/data/lock.py (1): FileLockLoader (17-68)
- src/axolotl/utils/data/pretraining.py (1): wrap_streaming_dataset (179-244)
- src/axolotl/utils/dict.py (1): DictDefault (6-38)
🪛 LanguageTool
examples/streaming/README.md
[grammar] ~3-~3: There might be a mistake here.
Context: ...ns for using Axolotl's streaming dataset functionality, which enables memory-effi...
(QB_NEW_EN)
[grammar] ~10-~10: There might be a mistake here.
Context: ...ning tasks using the fineweb-edu dataset with SmolLM2-135M. - Uses `pretraining_...
(QB_NEW_EN)
[grammar] ~21-~21: There might be a mistake here.
Context: ... streaming: true flag for SFT datasets - Evaluation dataset handling with streami...
(QB_NEW_EN)
[grammar] ~22-~22: There might be a mistake here.
Context: ...aluation dataset handling with streaming - Memory-efficient training on instruction...
(QB_NEW_EN)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: preview
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
🔇 Additional comments (24)
src/axolotl/utils/data/pretraining.py (6)
49-63: Potential double-EOS depending on tokenizer behavior. Some tokenizers add EOS when add_special_tokens=True; you then append another EOS here. That can insert double EOS boundaries. If intentional, fine; otherwise gate on whether the last token is already EOS.
Suggested defensive check:
- input_ids[i] = torch.cat(
+ # Avoid double-EOS if tokenizer already appended it
+ eos_pad = [tokenizer.eos_token_id, tokenizer.pad_token_id]
+ if len(input_ids[i]) > 0 and input_ids[i][-1].item() == tokenizer.eos_token_id:
+     eos_pad = [tokenizer.pad_token_id]  # only pad if EOS present
+ input_ids[i] = torch.cat(
      (
          input_ids[i],
-         torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
+         torch.tensor(eos_pad),
      ),
      dim=0,
  )
- targets[i] = torch.cat(
+ tgt_tail = [tokenizer.eos_token_id, -100] if len(eos_pad) == 2 else [-100]
+ targets[i] = torch.cat(
      (
          targets[i],
-         torch.tensor([tokenizer.eos_token_id, -100]),
+         torch.tensor(tgt_tail),
      ),
      dim=0,
  )
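To make the intended tail handling concrete, here is a minimal sketch of the same guard on plain Python lists. `append_eos_and_pad` is a hypothetical helper, not a function in this PR, and it assumes token ids are plain integers rather than tensors.

```python
def append_eos_and_pad(ids, eos_id, pad_id):
    """Return (input_ids, labels) ending in exactly one EOS followed by one pad.

    Hypothetical helper for illustration only; it mirrors the guard in the
    suggested diff but works on lists instead of torch tensors.
    """
    ids = list(ids)
    # Only add EOS when the tokenizer has not already appended it.
    tail = [pad_id] if ids and ids[-1] == eos_id else [eos_id, pad_id]
    input_ids = ids + tail
    # Labels keep the EOS but mask the pad position with -100.
    labels = ids + tail[:-1] + [-100]
    return input_ids, labels
```

Either way the sequence ends in a single EOS plus a masked pad slot, so the `max_tokens - 2` truncation budget is never exceeded.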
186-198: Multipack attention policy: SFT always True (good); pretraining uses cfg.pretrain_multipack_attn. This is consistent with isolating attention for SFT. No action needed.
231-235: Peeking a streaming dataset to infer remove_columns can consume one sample. Iterating a streaming dataset here risks dropping the first example for subsequent consumers in some dataset implementations.
Prefer dataset.features if available, or use dataset.take(1) (supported by HF streaming) to avoid side effects:
- if dataset.features is None:
-     for first_row in dataset:
-         remove_columns = list(first_row.keys())
-         break
+ if dataset.features is None:
+     try:
+         first = next(iter(dataset.take(1)))
+     except AttributeError:
+         first = next(iter(dataset))
+     remove_columns = list(first.keys())
Please test with an actual IterableDataset to confirm no sample loss.
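For sources that really are one-shot (a bare generator, say), a generic way to peek without losing the first row is to push it back onto the stream. The sketch below uses only the standard library; `peek_columns` is a hypothetical name, and whether a given dataset object needs this at all depends on whether it restarts on re-iteration, which should be verified.

```python
from itertools import chain


def peek_columns(rows):
    """Return (column_names, iterator) without dropping the first row.

    Hypothetical helper: works on any iterable of dict-like rows.
    Returns (None, empty iterator) when the stream has no rows.
    """
    it = iter(rows)
    try:
        first = next(it)
    except StopIteration:
        return None, iter(())
    # Re-attach the peeked row in front of the remaining rows.
    return list(first.keys()), chain([first], it)
```

Callers must consume the returned iterator rather than the original object, otherwise the peeked row is still lost.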
259-266: Attention-mask handling during packing: confirm collator/trainer compatibility. With drop_attention_mask=multipack_attn, features may lack "attention_mask", so the custom collator’s per-pack attention segmentation won’t apply. That matches the FIXME comment (rely on position_ids workaround), but it means multipack_attn in collator is effectively ignored when True.
- Confirm the trainer path correctly enables the position-ids isolation when attention masks are dropped.
- Consider consolidating the logic to avoid conflicting signals (e.g., pass multipack_attn only to the trainer path when dropping masks).
268-275: Sampler params look sensible for streaming-packed flow. batch_max_len=batch_size * max_seq_length with batch_size=1 leverages the multipack sampler as intended.
70-70: Confirmed Python floor ≥3.10; no changes needed. The project’s pyproject.toml declares `requires-python = ">=3.10"`, so using `zip(..., strict=False)` (which requires Python 3.10+) is safe.
_quarto.yml (1)
275-276: Good addition to Core Concepts navigation. docs/streaming.qmd placed right after dataset_preprocessing makes sense.
tests/test_data.py (1)
9-10: API import changes look correct. Switching to encode_streaming and keeping the md5 import aligned with the `__init__` exports is fine.
examples/streaming/README.md (1)
45-49: Confirm YAML numeric underscores are accepted by your loader. Some YAML parsers don’t accept 10_000 as a numeric literal. If your config loader handles it, great; otherwise quote it or use 10000.
If needed:
-streaming_multipack_buffer_size: 10_000
+streaming_multipack_buffer_size: 10000
docs/streaming.qmd (1)
88-89: Link target looks correct for same-directory doc; consider adding sidebar link text for clarity. No code change needed; just confirming relative link works given both are under docs/.
tests/test_packed_pretraining.py (1)
12-12: Import rename to streaming API looks correct. The switch to import wrap_streaming_dataset aligns with the public API change and keeps this test on the streaming path. No issues.
tests/e2e/test_streaming.py (2)
18-21: Good coverage: parameterizing packing on/off. Exercising both sample_packing=True/False in the same test keeps the streaming path honest across packer modes.
60-66: End-to-end flow and postconditions look solid. validate_config + normalize_config before load_datasets/train, followed by artifact checks and a basic loss sanity check via TensorBoard is a pragmatic E2E signal.
src/axolotl/utils/schemas/config.py (2)
500-506: Confirmed: `pretraining_sample_concatenation` is still referenced correctly, no legacy aliases detected. All occurrences of the relocated field use the exact same name, so the move won’t affect runtime behavior:
- Declaration remains at src/axolotl/utils/schemas/config.py:500
- Used in data loader: src/axolotl/utils/data/pretraining.py:217
- Used in model builder: src/axolotl/core/builders/causal.py:428
No references to any outdated or aliased field name were found.
507-507: No changes needed: streaming defaults to False end-to-end. I’ve traced all the key code paths:
- In the schema, `streaming: bool | None = None` allows users to omit the flag.
- In data loading (utils/data/shared.py), the function signature is `def load_dataset(..., streaming: bool = False, ...): ... return {"streaming": streaming, …}`, so if the config value is None, the default False is used.
- Callers either pass an explicit boolean or rely on that default:
  - RL pipeline in utils/data/rl.py always calls with `streaming=False`, regardless of config.
  - SFT pipeline in utils/data/sft.py gates on `cfg.streaming` (where None is falsy) and only requests streaming when explicitly set.
- Validation in utils/schemas/validation.py checks `data.get("streaming")` (where None → False) before applying streaming-only rules.
All entry points thus fall back deterministically to non-streaming when `streaming` is unset. No further changes are required here.
examples/streaming/pretrain.yaml (1)
36-38: Double-check that `bf16: auto` is accepted by your config/Trainer path. Some stacks expect `bf16` to be a boolean and infer automatically based on hardware without the `"auto"` sentinel. If your schema explicitly supports `"auto"`, ignore this.
tests/test_streaming.py (3)
143-165: Potential fragility: test depends on dict-like `cfg.datasets[0]` being attribute-accessible. Current implementation of `_prepare_streaming_dataset` (see sft.py Lines 141–147) reads `cfg.datasets[0].type`. If `cfg.datasets[0]` is a plain dict (as in this test), that attribute access would raise. I’m proposing a fix in sft.py to robustly handle both dicts and model instances. If you don’t adopt that fix, consider wrapping the config via `validate_config` here before calling `_prepare_streaming_dataset`.
166-200: Good check: SFT forces `multipack_attn=True`. This verifies our intended invariant for SFT packing. Looks solid.
201-235: Good check: pretraining respects `pretrain_multipack_attn`. This keeps flexibility for pretraining while enforcing attention isolation for SFT.
src/axolotl/utils/data/sft.py (5)
63-65: Routing looks correct. Delegating to streaming path when either `cfg.streaming` or `cfg.pretraining_dataset` is set is consistent with the new model.
171-176: Nice: clear messaging and step computation for streaming. Explicitly logging unavailability of dedup for streaming and returning `max_steps` or -1 is reasonable.
243-249: Wrapper switch to `wrap_streaming_dataset` looks good. This aligns with the streaming-first approach and matches the collator/encoding flow used elsewhere.
385-394: Good: `IterableDatasetDict` support on split selection. This unifies handling for both in-memory and streaming dict datasets.
349-356: Correct: skip long-seq processing for iterable preprocessing. Avoiding `handle_long_seq_in_dataset` when `preprocess_iterable=True` prevents loading the entire stream into memory.
Actionable comments posted: 1
🧹 Nitpick comments (6)
examples/streaming/README.md (6)
8-9: Clarify run instructions, add both example commands, and note deprecation. Improves readability, avoids the semicolon splice, and surfaces the `--iterable` deprecation where users will see it.
-Run the following examples with e.g. `axolotl train examples/streaming/sft.yaml`; no
-`axolotl preprocess` required!
+Run either of the following (no axolotl preprocess step required):
+
+```bash
+axolotl train examples/streaming/pretrain.yaml
+axolotl train examples/streaming/sft.yaml
+```
+
+Note: The `--iterable` flag in `axolotl preprocess` is deprecated; use `streaming: true` (for SFT) or `pretraining_dataset` (for pretraining) instead.
13-15: Be explicit about dataset slugs and config knobs (pretrain_multipack_attn, buffer). Names the knobs users will look for and aligns phrasing with the examples/config.
-Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset
-with SmolLM2-135M.
+Demonstrates streaming configuration for pretraining using HuggingFaceFW/fineweb-edu
+with SmolLM2-135M.
-- Uses `pretraining_dataset` configuration for automatic streaming
-- Multipack attention control to prevent cross-attention between packed sequences
-- Buffer size configuration for memory management
+- Uses `pretraining_dataset` for automatic streaming
+- Enable `pretrain_multipack_attn` to prevent cross-sample attention when packing
+- Tune buffer size (`streaming_multipack_buffer_size`) for memory/packing trade-offs
Also applies to: 16-18
35-37: Avoid hard-coding the default buffer size unless guaranteed. If the code’s default changes, this doc goes stale. Either remove the default or verify it matches the implementation.
-- Controls buffer size for sample packing (default: 10,000)
+- Controls buffer size for sample packing
43-46: Tighten grammar for `sample_packing` bullets.
-- Packs multiple samples into single sequences
-- Minimize per-step padding tokens
+- Packs multiple samples into a single sequence
+- Minimizes per-step padding tokens
47-50: Add actionable performance tips for low-memory environments. Surface concrete mitigations users will reach for when hitting OOM or slowness.
 - Download small / frequently-used datasets locally for better performance
 - Larger buffer sizes improve packing efficiency
+- If memory is constrained or you hit OOM, reduce `streaming_multipack_buffer_size` or disable `shuffle_merged_datasets`
+- Remote datasets may benefit from local caching to avoid network bottlenecks
6-7: Cross-reference the detailed docs page. Helps users discover the full guidance added in this PR.
 ## Examples
+
+For a deeper dive, see `docs/streaming.qmd` in the repository.
📒 Files selected for processing (1)
examples/streaming/README.md(1 hunks)
🧰 Additional context used
🪛 LanguageTool
examples/streaming/README.md
[grammar] ~3-~3: There might be a mistake here.
Context: ...ns for using Axolotl's streaming dataset functionality, which enables memory-effi...
(QB_NEW_EN)
[grammar] ~8-~8: There might be a mistake here.
Context: ...l train examples/streaming/sft.yaml; no axolotl preprocess` required! ### Pret...
(QB_NEW_EN)
[grammar] ~13-~13: There might be a mistake here.
Context: ...ning tasks using the fineweb-edu dataset with SmolLM2-135M. - Uses `pretraining_...
(QB_NEW_EN)
[grammar] ~24-~24: There might be a mistake here.
Context: ... streaming: true flag for SFT datasets - Evaluation dataset handling with streami...
(QB_NEW_EN)
[grammar] ~25-~25: There might be a mistake here.
Context: ...aluation dataset handling with streaming - Memory-efficient training on instruction...
(QB_NEW_EN)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
examples/streaming/README.md (1)
51-51: Fix truncated line at end of file. There’s a stray line “51” that will render incorrectly. Remove it to avoid a formatting artifact.
-51
+
src/axolotl/utils/data/sft.py (2)
212-219: Preserve prompt style by parsing dataset type before partializing get_dataset_wrapper. get_dataset_wrapper in the non-streaming path receives both base_type and prompt_style (parsed via _parse_dataset_type). Here only dataset_base_type is provided; prompt style will be lost for types like "sft:chatml".
Apply:
- dataset_wrapper_partial = functools.partial(
-     get_dataset_wrapper,
-     dataset_config=pretraining_config,
-     tokenizer=tokenizer,
-     cfg=cfg,
-     dataset_base_type=pretraining_config["type"],
- )
+ d_base_type, d_prompt_style = _parse_dataset_type(pretraining_config.get("type"))
+ dataset_wrapper_partial = functools.partial(
+     get_dataset_wrapper,
+     dataset_config=pretraining_config,
+     tokenizer=tokenizer,
+     cfg=cfg,
+     dataset_base_type=d_base_type,
+     dataset_prompt_style=d_prompt_style,
+ )
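For reference, a rough sketch of the kind of split this relies on. The real `_parse_dataset_type` is not shown in this thread, so `split_dataset_type` below is a hypothetical stand-in that assumes a single colon separates base type from prompt style, as in the "sft:chatml" example.

```python
def split_dataset_type(ds_type):
    """Split 'base:style' into (base, style); style is None when absent.

    Hypothetical stand-in for illustration; the project's own parser may
    handle more cases than this.
    """
    if not isinstance(ds_type, str) or not ds_type:
        return None, None
    # partition() splits on the first colon only and never raises.
    base, _, style = ds_type.partition(":")
    return base, (style or None)
```

Passing only the unsplit string as the base type is what would drop the style portion.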
228-235: Bug: streaming loader bypasses auth and shared loader logic. Directly calling datasets.load_dataset omits cfg.hf_use_auth_token (or token) and any central config logic in load_dataset_with_config. Private datasets will fail.
Apply:
- iter_dataset = load_dataset(
-     pretraining_config["path"],
-     streaming=True,
-     split=pretraining_config["split"],
-     name=pretraining_config["name"],
-     data_files=pretraining_config["data_files"],
- )
+ ds = load_dataset_with_config(
+     pretraining_config,
+     cfg.hf_use_auth_token,
+     streaming=True,
+ )
+ if isinstance(ds, (DatasetDict, IterableDatasetDict)):
+     split = pretraining_config.get("split") or "train"
+     iter_dataset = ds[split]
+ else:
+     iter_dataset = ds
I can add a regression test that uses a dummy private dataset gate to ensure token propagation on streaming paths.
♻️ Duplicate comments (2)
examples/streaming/README.md (1)
20-27: SFT docs: mention pretrain_multipack_attn and clarify eval is not streamed (with local-caching tip). Align with project docs/tests: eval is not streamed; recommend local caching. Also call out the attention mask flag for SFT with packing.
 - Explicit `streaming: true` flag for SFT datasets
-- Memory-efficient training on instruction datasets
-- Evaluation datasets are currently not streamed
+- Enable `pretrain_multipack_attn` when using `sample_packing` to avoid cross-sample attention
+- Memory-efficient training on instruction datasets
+- Evaluation datasets are not streamed; prefer smaller eval sets or local caching
src/axolotl/utils/data/sft.py (1)
141-146: Fix: Normalize cfg.datasets[0] to a DictDefault and ensure type/split are present. As previously flagged, directly doing DictDefault(cfg.datasets[0]) breaks when the first entry is not a plain dict (e.g., Pydantic model). Also, missing type leads to None flowing into get_dataset_wrapper later. Normalize the shape and guard attributes.
Apply:
- dataset_config = DictDefault(cfg.datasets[0])
-
- # Ensure we have a split set - default to 'train' if not specified
- if not hasattr(dataset_config, "split") or not dataset_config.split:
-     dataset_config.split = "train"
+ first = cfg.datasets[0]
+ # Normalize to a DictDefault mapping regardless of source type
+ if hasattr(first, "model_dump"):
+     dataset_config = DictDefault(first.model_dump())
+ elif isinstance(first, dict):
+     dataset_config = DictDefault(first)
+ else:
+     dataset_config = DictDefault(dict(first))
+
+ # Ensure `type` is present
+ if not getattr(dataset_config, "type", None):
+     dataset_config.type = getattr(first, "type", dataset_config.get("type"))
+
+ # Ensure split default
+ if not getattr(dataset_config, "split", None):
+     dataset_config.split = "train"
Happy to wire a small unit test covering dict, pydantic, and namedtuple inputs.
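A standalone sketch of the normalization, covering the three input shapes mentioned. `to_plain_dict` is a hypothetical helper that returns a plain dict rather than a DictDefault, and it detects Pydantic-style models only by the presence of `model_dump`. One caveat worth checking: `dict()` on a namedtuple iterates its values, not key/value pairs, so the `dict(first)` fallback in the diff is unlikely to work for that case; `_asdict()` is used here instead.

```python
from collections import namedtuple  # used in the usage example
from collections.abc import Mapping


def to_plain_dict(entry, default_split="train"):
    """Normalize a dataset config entry to a dict with a split set.

    Hypothetical helper for illustration; not part of the PR.
    """
    if hasattr(entry, "model_dump"):      # Pydantic v2-style model
        data = dict(entry.model_dump())
    elif isinstance(entry, Mapping):      # plain dict or dict-like
        data = dict(entry)
    elif hasattr(entry, "_asdict"):       # namedtuple
        data = dict(entry._asdict())
    else:
        raise TypeError(f"unsupported dataset config type: {type(entry)!r}")
    # Fall back to the default split when it is missing or empty.
    if not data.get("split"):
        data["split"] = default_split
    return data
```

For example, `to_plain_dict(namedtuple("Row", ["path", "type"])("p", "alpaca"))` yields a dict with `split` set to "train".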
🧹 Nitpick comments (13)
examples/streaming/README.md (7)
8-10: Clarify preprocess deprecation and fix punctuation around “e.g.”. Make it explicit that preprocess is no longer needed due to deprecation of the iterable mode, and fix the “e.g.” punctuation.
-Run the following examples with e.g. `axolotl train examples/streaming/sft.yaml`; no
-`axolotl preprocess` required!
+Run the following examples with, e.g., `axolotl train examples/streaming/sft.yaml`. No
+`axolotl preprocess` step is required (the `preprocess --iterable` mode is deprecated).
13-15: Capitalize dataset name for consistency. “FineWeb-Edu” is typically capitalized; minor readability polish.
-Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset
-with SmolLM2-135M.
+Demonstrates streaming configuration for pretraining tasks using the FineWeb-Edu dataset
+with SmolLM2-135M.
16-19: Name the exact flag: use “pretrain_multipack_attn”. Call out the specific config knob to avoid ambiguity, matching the PR’s behavior change.
-- Multipack attention control to prevent cross-attention between packed sequences
+- Enable `pretrain_multipack_attn` to prevent cross-sample attention when packing
34-38: Optional: note scope of the buffer (CPU RAM) and tuning guidance. A brief qualifier helps users plan memory; safe to keep concise.
 - Controls buffer size for sample packing (default: 10,000)
-- Larger values improve packing efficiency but use more memory
+- Larger values improve packing efficiency but use more CPU RAM
 - Adjust based on available memory
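To illustrate why a larger buffer tends to pack better, here is a toy first-fit-decreasing packer that only sees one buffer of sample lengths at a time. This is not the multipack sampler used in the PR; `pack_with_buffer` and `fill_ratio` are hypothetical names and the numbers are made up for the example.

```python
def pack_with_buffer(lengths, max_len, buffer_size):
    """Pack sample lengths into bins of capacity max_len, one buffer at a time."""
    packs = []
    for start in range(0, len(lengths), buffer_size):
        # The packer can only rearrange samples inside the current buffer.
        window = sorted(lengths[start:start + buffer_size], reverse=True)
        bins = []
        for n in window:
            if n > max_len:
                raise ValueError(f"sample of length {n} exceeds max_len={max_len}")
            for b in bins:
                if sum(b) + n <= max_len:
                    b.append(n)
                    break
            else:
                bins.append([n])
        packs.extend(bins)
    return packs


def fill_ratio(packs, max_len):
    """Fraction of available token slots that hold real tokens."""
    return sum(map(sum, packs)) / (len(packs) * max_len)
```

With lengths `[6, 4, 6, 4]` and `max_len=10`, a buffer of 1 produces four half-empty packs, while a buffer of 4 produces two full ones, at the cost of holding four samples in memory.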
39-42: Add reproducibility note for shuffling. Suggest mentioning `shuffle_seed` to aid deterministic experiments.
 - Enables shuffling of streaming datasets
-- Requires additional memory for shuffle buffer
+- Requires additional memory for shuffle buffer; set `shuffle_seed` for reproducibility
43-46: Reinforce attention masking requirement when packing. Place the guidance near the option users will toggle.
 ### `sample_packing`
 - Packs multiple samples into single sequences
 - Minimize per-step padding tokens
+- When packing, set `pretrain_multipack_attn: true` to block attention across sample boundaries
49-50: Broaden performance tips with eval caching and buffer tuning advice. These two bullets capture common pitfalls in streaming runs.
 - Download small / frequently-used datasets locally for better performance
-- Larger buffer sizes improve packing efficiency
+- Larger buffer sizes improve packing efficiency
+- Cache evaluation datasets locally to eliminate network variability
+- Sweep `streaming_multipack_buffer_size` (e.g., 2k/5k/10k) to find the best tradeoff for your RAM
src/axolotl/utils/data/sft.py (6)
63-66: Confirm routing: pretraining_dataset forces the streaming path, even if cfg.streaming is False. Is this intentional? Older configs may set pretraining_dataset for other flows. If the intention is “streaming whenever pretraining_dataset is configured,” keep as-is; otherwise, guard with cfg.streaming to avoid surprising behavior.
If you want to require explicit streaming, consider:
- if cfg.streaming or cfg.pretraining_dataset:
+ if cfg.streaming and cfg.pretraining_dataset:
      return _prepare_streaming_dataset(cfg, tokenizer, processor)
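One thing to weigh before switching to `and`: the two predicates disagree on more than the legacy case. The sketch below just enumerates the combinations; the function names are hypothetical and the config values are stand-ins.

```python
def routes_with_or(streaming, pretraining_dataset):
    """Current predicate: either setting selects the streaming path."""
    return bool(streaming or pretraining_dataset)


def routes_with_and(streaming, pretraining_dataset):
    """Proposed predicate: both settings are required."""
    return bool(streaming and pretraining_dataset)


# (streaming, pretraining_dataset) combinations and what each predicate does.
CASES = {
    "sft_streaming_only": (True, None),
    "pretraining_only": (None, [{"path": "some/dataset"}]),
    "both": (True, [{"path": "some/dataset"}]),
    "neither": (None, None),
}

ROUTING = {
    name: (routes_with_or(*args), routes_with_and(*args))
    for name, args in CASES.items()
}
```

Under `and`, an SFT config with `streaming: true` and no `pretraining_dataset` would no longer reach the streaming path, so a narrower guard may be needed if explicit opt-in is the goal.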
162-168: Evaluation in streaming mode currently disables iterable preprocessing; consider making it opt-in. Hard-coding preprocess_iterable=False means test_datasets won’t stream, which can be surprising and memory-heavy for large eval sets. Propagate a config flag (e.g., cfg.eval_streaming) or default to cfg.streaming.
Suggested tweak:
- _, eval_dataset, _ = _load_and_prepare_datasets(
+ _, eval_dataset, _ = _load_and_prepare_datasets(
      tokenizer,
      cfg,
      split="test",
      processor=processor,
-     preprocess_iterable=False,
+     preprocess_iterable=getattr(cfg, "eval_streaming", cfg.streaming),
  )
171-172: Deduplication log clarity. Message is accurate: exact dedup isn’t available with streaming. Consider hinting at the non-streaming alternative (preprocess offline) for users encountering this log.
173-175: Prompters list is empty in some streaming paths; verify downstream expectations. Standard paths return a non-empty prompters list; here you return [] (and earlier return a real list only in the non-packed streaming branch). If any trainer code assumes length to match datasets, this could break formatting/metrics hooks.
Option: thread the dataset_prompter from get_dataset_wrapper through wrap_streaming_dataset (or return a [None] sentinel instead of []) for API consistency.
210-210: Docstring mismatch: function now serves SFT as well. The docstring says “for pretraining” but this function is used by SFT when sample_packing is enabled.
Apply:
- """Load and prepare a streaming dataset for pretraining."""
+ """Load and prepare a streaming dataset (SFT or pretraining)."""
253-260: Optional: avoid temp file leakage for placeholder dataset. NamedTemporaryFile(delete=False) never unlinks, leaving artifacts. If dispatch_batches ensures non-main ranks don’t iterate, a generator-based IterableDataset avoids touching disk.
Example (no diff since it changes implementation detail):
from datasets import IterableDataset

def _create_placeholder_dataset() -> IterableDataset:
    def gen():
        yield {"text": "lorem ipsum dolor sit amet"}
    return IterableDataset.from_generator(gen)
If you keep the CSV, consider os.unlink(f.name) after constructing the dataset (but ensure the dataset doesn’t need the file thereafter).
📒 Files selected for processing (5)
docs/streaming.qmd(1 hunks)examples/streaming/README.md(1 hunks)examples/streaming/pretrain.yaml(1 hunks)examples/streaming/sft.yaml(1 hunks)src/axolotl/utils/data/sft.py(8 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- examples/streaming/pretrain.yaml
- examples/streaming/sft.yaml
- docs/streaming.qmd
🧰 Additional context used
🧬 Code graph analysis (1)
src/axolotl/utils/data/sft.py (2)
- src/axolotl/utils/data/pretraining.py (1): wrap_streaming_dataset (179-244)
- src/axolotl/utils/dict.py (1): DictDefault (6-38)
🪛 LanguageTool
examples/streaming/README.md
[grammar] ~3-~3: There might be a mistake here.
Context: ...ns for using Axolotl's streaming dataset functionality, which enables memory-effi...
(QB_NEW_EN)
[grammar] ~8-~8: There might be a mistake here.
Context: ...l train examples/streaming/sft.yaml; no axolotl preprocess` required! ### Pret...
(QB_NEW_EN)
[grammar] ~13-~13: There might be a mistake here.
Context: ...ning tasks using the fineweb-edu dataset with SmolLM2-135M. - Uses `pretraining_...
(QB_NEW_EN)
[grammar] ~24-~24: There might be a mistake here.
Context: ... streaming: true flag for SFT datasets - Memory-efficient training on instruction...
(QB_NEW_EN)
[grammar] ~25-~25: There might be a mistake here.
Context: ...ficient training on instruction datasets - Evaluation datasets are currently not st...
(QB_NEW_EN)
🔇 Additional comments (6)
examples/streaming/README.md (1)
30-33: LGTM: streaming option description is clear and accurate. Accurately reflects implicit enablement via `pretraining_dataset`.
src/axolotl/utils/data/sft.py (5)
12-13: Good addition: IterableDatasetDict support for split selection. Importing IterableDatasetDict enables correct split handling for streaming dataset dicts below. This aligns with your later isinstance checks.
19-19: Importing wrap_streaming_dataset is appropriate for SFT streaming. This consolidates the streaming encode/wrap logic and keeps SFT aligned with the pretraining path.
148-157: Non-packed streaming path via preprocess_iterable=True looks correct. Returning early with total_num_steps=-1 (or cfg.max_steps) matches streaming semantics and avoids eager preprocessing.
348-355: Skipping long-seq handling under preprocess_iterable is correct. This avoids materializing large transformed datasets when streaming. Packing remains gated as expected.
384-389: Split selection now supports IterableDatasetDict; good catch. This aligns the streaming and non-streaming code paths and prevents KeyErrors on split resolution.
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/axolotl/utils/data/sft.py (1)
139-151: Potential issue with unsafe attribute access. The code assumes `cfg.datasets[0]` has a `type` attribute on line 146, which will fail if `cfg.datasets[0]` is a plain dict. This needs defensive handling.
Apply this fix to safely handle both dict and object types:
- dataset_config = DictDefault(cfg.datasets[0])
+ # Safely extract and normalize the first dataset config
+ first_dataset = cfg.datasets[0]
+ if isinstance(first_dataset, dict):
+     dataset_config = DictDefault(first_dataset)
+ else:
+     dataset_config = DictDefault(first_dataset.model_dump() if hasattr(first_dataset, 'model_dump') else dict(first_dataset))
  # Ensure we have a split set - default to 'train' if not specified
  if not hasattr(dataset_config, "split") or not dataset_config.split:
      dataset_config.split = "train"
🧹 Nitpick comments (2)
src/axolotl/utils/data/streaming.py (1)
1-1: Documentation clarification needed. The module docstring should be updated to reflect the expanded scope beyond just streaming, as these functions now handle both streaming and pretraining datasets.
-"""Data handling specific to streaming datasets."""
+"""Data handling for streaming and pretraining datasets with multipack support."""
src/axolotl/utils/data/sft.py (1)
178-178: Consider using logging.warning for important user-facing messages. Since deduplication is silently skipped for streaming datasets, using `LOG.warning` instead of `LOG.info` would better alert users to this limitation.
- LOG.info("Deduplication not available for streaming datasets")
+ LOG.warning("Deduplication not available for streaming datasets")
📒 Files selected for processing (4)
src/axolotl/utils/data/__init__.py(2 hunks)src/axolotl/utils/data/sft.py(15 hunks)src/axolotl/utils/data/streaming.py(4 hunks)tests/test_streaming.py(1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/test_streaming.py
- src/axolotl/utils/data/__init__.py
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/utils/data/streaming.py (1)
- src/axolotl/utils/collators/batching.py (1): PretrainingBatchSamplerDataCollatorForSeq2Seq (200-227)
src/axolotl/utils/data/sft.py (2)
- src/axolotl/utils/data/streaming.py (1): wrap_streaming_dataset (179-247)
- src/axolotl/utils/data/shared.py (2): merge_datasets (525-565), load_dataset_with_config (93-148)
🔇 Additional comments (15)
src/axolotl/utils/data/streaming.py (6)
20-26: Function name accurately reflects the expanded functionality. Good job on renaming from `encode_pretraining` to `encode_streaming` to better represent the function's expanded use case beyond just pretraining.
179-210: Great implementation of multipack attention logic for SFT datasets. The conditional logic correctly ensures that SFT datasets use `multipack_attn=True` for proper attention isolation between packed sequences. The comment clearly explains the reasoning, and setting `cfg.micro_batch_size = 1` prevents double batching.
212-221: Clear documentation of unreachable code path. The comment effectively explains why this code path is not reachable for SFT datasets and points to the correct location in `sft.py`.
224-226: Proper parameter naming for clarity. The change from generic `buffer_size` to `streaming_multipack_buffer_size` improves code clarity by making the parameter's purpose explicit.
250-257: Function renamed appropriately to match module pattern. The rename from `encode_packed_pretraining` to `encode_packed_streaming` maintains consistency with the other renamed functions.
285-286: Good cleanup: removed redundant deletion. The removal of the duplicate `del features["num_truncated_tokens"]` line eliminates unnecessary redundancy.
src/axolotl/utils/data/sft.py (9)
12-12: Necessary import for streaming dataset support.The addition of
IterableDatasetDictimport properly supports the new streaming functionality.
19-19: Import updated to match renamed module functions.The import correctly reflects the renamed
wrap_streaming_datasetfunction.
63-65: Clean routing logic for streaming vs standard datasets.The conditional routing based on
cfg.streamingorcfg.pretraining_datasetis clear and straightforward.
125-183: Well-structured streaming dataset preparation function.The function properly handles both pretraining and SFT streaming cases with appropriate error handling for unsupported configurations. The early return for non-packed streaming datasets is efficient.
214-257: Clean implementation of streaming dataset loader.The function properly handles dataset loading with appropriate placeholder creation for non-main processes and correct application of skip logic.
274-274: Parameter rename improves clarity.The rename from
preprocess_iterabletostreamingbetter reflects the parameter's purpose in the context of streaming dataset support.Also applies to: 283-283, 314-314, 326-326, 347-347, 379-379
355-355: Proper conditional logic for non-streaming operations.The condition correctly prevents long sequence handling and dataset saving for streaming datasets, which is appropriate since streaming datasets are processed on-the-fly.
391-391: Type checking expanded for streaming support.The addition of
IterableDatasetDictto the type check properly handles streaming dataset dictionaries.
384-384: Streaming parameter properly propagated.The addition of the
streamingparameter toload_dataset_with_configensures consistent behavior throughout the loading pipeline.
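For readers unfamiliar with the "attention isolation between packed sequences" mentioned in the 179-210 comment, here is a minimal sketch of the general idea. It is not Axolotl's implementation: the helper names `segment_ids` and `isolated_causal_mask` are hypothetical, and it uses plain Python lists rather than tensors. Each token is tagged with the index of the sample it came from, and a token may only attend to earlier tokens carrying the same tag.

```python
from typing import List


def segment_ids(lengths: List[int]) -> List[int]:
    """Label each token position with the index of the sample it belongs to."""
    ids: List[int] = []
    for sample_index, length in enumerate(lengths):
        ids.extend([sample_index] * length)
    return ids


def isolated_causal_mask(lengths: List[int]) -> List[List[bool]]:
    """Causal mask in which a query token only sees keys from its own sample.

    mask[query][key] is True when attention is allowed. (Hypothetical helper,
    for illustration only.)
    """
    ids = segment_ids(lengths)
    total = len(ids)
    return [
        [key <= query and ids[key] == ids[query] for key in range(total)]
        for query in range(total)
    ]


# Two samples of lengths 2 and 3 packed into a single row of 5 tokens.
mask = isolated_causal_mask([2, 3])
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
```

Without the `ids[key] == ids[query]` term the mask would be a plain causal mask, and the second sample's tokens would attend to the first sample's tokens.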
Actionable comments posted: 2
🧹 Nitpick comments (4)
docs/streaming.qmd (4)
67-71: Spell out behavior for multiple datasets and determinism
Clarify that shuffling is over the merged stream, and call out seed requirements for reproducibility (important for streaming).
Apply this diff:
-When enabled, shuffles the streaming dataset using the buffer. This requires additional
-memory for the shuffle buffer.
+When enabled, shuffles the merged streaming dataset using the buffer (not per-dataset).
+This requires additional memory for the shuffle buffer. For reproducibility across runs
+and distributed workers, set a training seed.
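To make the seed point concrete, here is a minimal sketch of a buffer-based shuffle over a stream. It assumes the usual reservoir-style scheme (fill a buffer, then swap a random slot out for each new item). The function name `buffered_shuffle` is hypothetical and is not the library's API. With a fixed seed the output order is reproducible, and the extra memory is bounded by the buffer size.

```python
import itertools
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(stream: Iterable[T], buffer_size: int, seed: int) -> Iterator[T]:
    """Approximately shuffle a stream while holding at most `buffer_size` items."""
    rng = random.Random(seed)  # private RNG so the order depends only on `seed`
    buffer: List[T] = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        index = rng.randrange(buffer_size)
        yield buffer[index]  # emit a random buffered item...
        buffer[index] = item  # ...and put the new item in its slot
    rng.shuffle(buffer)  # drain what is left at end of stream
    yield from buffer


# Two toy "datasets" merged into one stream before shuffling (an assumption
# made for illustration; a real merge may interleave rather than concatenate).
merged = list(itertools.chain(range(0, 10), range(100, 110)))
first = list(buffered_shuffle(merged, buffer_size=5, seed=42))
second = list(buffered_shuffle(merged, buffer_size=5, seed=42))
print(first == second)
```

Note that a small buffer only mixes items that are close together in the stream, which is why shuffling the merged stream behaves differently from shuffling each dataset separately.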
47-57: Optional: surface the SFT attention isolation in the SFT example
Even if force-enabled, a short comment in the SFT snippet helps users discover the behavior in the most relevant place.
Apply this diff:
 streaming: true
 datasets:
   - path: tatsu-lab/alpaca
     type: alpaca
     split: train
 # Optionally, enable sample packing
 streaming_multipack_buffer_size: 10000
 sample_packing: true
+# Note: In SFT, attention isolation between packed samples is automatically enabled.
17-18: Add a short migration note for removed CLI flag and renamed options
Users migrating from preprocess-time iterable mode will look for a canonical mapping in this page.
Apply this diff to add a migration subsection right after the “Configuration” header:
 ## Configuration
+### Migration notes
+
+- The `axolotl preprocess --iterable` CLI flag has been removed; streaming now flows through the training path.
+- `pretrain_multipack_buffer_size` has been renamed to `streaming_multipack_buffer_size`.
+- In SFT, attention isolation for packed samples is force-enabled when `sample_packing: true`; you do not need to set `pretrain_multipack_attn` for SFT.
+
63-66: Reference schema for default multipack buffer size and offer a starting value
The docs should avoid hard-coding the default and instead point to the config schema, while still giving “10,000” as a practical starting point (default confirmed as 10_000 in the schema).
- File: `docs/streaming.qmd`
- Lines: 63–66
Proposed diff:
-Controls the buffer size for multipack streaming (default: 10,000). This determines how
-many samples are buffered before packing. Larger buffers can improve packing efficiency
-but use more memory.
+Controls the buffer size for multipack streaming. This determines how many samples are
+buffered before packing. Larger buffers can improve packing efficiency but use more host
+memory. The default is defined in the config schema (a good starting point is 10,000).
Verification:
- In `src/axolotl/utils/schemas/config.py` (line 508): `streaming_multipack_buffer_size: int | None = 10_000`
- Examples in `examples/streaming/*.yaml` also set `streaming_multipack_buffer_size: 10000`
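The trade-off between buffer size and packing efficiency can be illustrated with a small sketch. It assumes a simple first-fit-decreasing strategy applied to one buffer at a time. The names `pack_buffer` and `pack_stream` are hypothetical, and the real packer may use a different algorithm.

```python
from typing import Iterable, Iterator, List


def pack_buffer(lengths: List[int], capacity: int) -> List[List[int]]:
    """First-fit-decreasing packing of one buffer of sample lengths.

    Assumes every length is <= capacity (samples truncated beforehand).
    """
    bins: List[List[int]] = []
    for length in sorted(lengths, reverse=True):
        for current in bins:
            if sum(current) + length <= capacity:
                current.append(length)  # fits in an existing packed row
                break
        else:
            bins.append([length])  # no row had room, so start a new one
    return bins


def pack_stream(
    lengths: Iterable[int], buffer_size: int, capacity: int
) -> Iterator[List[int]]:
    """Pack a stream of sample lengths, looking at `buffer_size` samples at a time."""
    buffer: List[int] = []
    for length in lengths:
        buffer.append(length)
        if len(buffer) == buffer_size:
            yield from pack_buffer(buffer, capacity)
            buffer = []
    if buffer:  # flush the final partial buffer
        yield from pack_buffer(buffer, capacity)


lengths = [5, 4, 3, 6, 2, 7, 4, 1]
small = list(pack_stream(lengths, buffer_size=2, capacity=8))
large = list(pack_stream(lengths, buffer_size=8, capacity=8))
print(len(small), len(large))  # prints "7 4"
```

In this toy input the larger buffer lets the packer pair long and short samples, so the same eight samples fit in four rows instead of seven. The cost is that all eight lengths (and, in practice, their token data) are held in memory at once.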
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
docs/streaming.qmd (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
🔇 Additional comments (4)
docs/streaming.qmd (4)
7-16: Nice concise intro and “when to use” guidance
Clear and actionable rationale for streaming, including the immediate-start benefit and remote/local parity.
86-86: Link verification complete
The file `docs/multipack.qmd` exists alongside `docs/streaming.qmd`, so the relative link `[documentation](multipack.qmd)` will resolve correctly. No changes required.
111-114: Update example references to clickable links
We’ve verified that the files `examples/streaming/pretrain.yaml` and `examples/streaming/sft.yaml` exist in the repo and are reachable via `../examples/streaming/...` from `docs/streaming.qmd`. Assuming Quarto will preserve these as static assets, you can turn them into clickable links:
Files checked:
• examples/streaming/pretrain.yaml
• examples/streaming/sft.yaml
Replace the block at lines 111–114 in `docs/streaming.qmd` with the following diff:
@@ docs/streaming.qmd
-See the `examples/streaming/` directory for complete configuration examples:
-- `pretrain.yaml`: Pretraining with streaming dataset
-- `sft.yaml`: Supervised fine-tuning with streaming
+See the examples in the repository:
+- [examples/streaming/pretrain.yaml](../examples/streaming/pretrain.yaml): Pretraining with streaming dataset
+- [examples/streaming/sft.yaml](../examples/streaming/sft.yaml): Supervised fine-tuning with streaming
If your site build cannot resolve repo-relative links, you can revert to the plain-text list.
31-41: Validation completed: pretraining snippet keys are correct
The `PretrainingDataset` schema in src/axolotl/utils/schemas/datasets.py defines:
type: str | None = "pretrain"
split: str | None = "train"
text_column: str | None = "text"
These match the YAML snippet’s `type: pretrain`, `split`, and `text_column` keys. No changes needed.
Actionable comments posted: 1
🧹 Nitpick comments (3)
src/axolotl/utils/collators/__init__.py (1)
1-1: Capitalize proper noun in docstring ("Mamba")
Minor clarity/style: capitalize the model name.
-"""Shared axolotl collators for multipacking, mamba, multimodal."""
+"""Shared axolotl collators for multipacking, Mamba, multimodal."""
src/axolotl/utils/schemas/validation.py (2)
63-76: Deprecation warning reads well; consider one-shot logging and pinning version in the message.
- To avoid log spam in distributed runs, prefer LOG.warning_once here (used elsewhere in this file) or pass main_process_only=True if supported by your logger wrapper.
- The TODO mentions 0.13.0, but the user-facing message doesn’t. If the target is firm, include it in the warning to make the migration window explicit.
Apply this minimal tweak:
-        LOG.warning(
-            "Setting `pretraining_dataset` without explicitly setting `streaming: "
-            "true` is deprecated. In a future release, streaming will not be "
+        LOG.warning_once(
+            "Setting `pretraining_dataset` without explicitly setting `streaming: "
+            "true` is deprecated and will change in Axolotl 0.13.0. Streaming will not be "
             "automatically enabled when using pretraining_dataset. Please "
             "explicitly set `streaming: true` in your configuration to maintain "
             "current behavior."
         )
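For context on the one-shot logging suggestion, here is a minimal standard-library sketch of what a "warn once" helper can look like. It is not the project's logger wrapper: the function `warn_once` and the logger name `streaming_example` are hypothetical. The deduplication is also per process, so in a distributed run each rank would still emit the message once unless logging is additionally restricted to the main process.

```python
import functools
import logging

LOG = logging.getLogger("streaming_example")  # hypothetical logger name


@functools.lru_cache(maxsize=None)
def warn_once(message: str) -> None:
    """Emit `message` at WARNING level only the first time it is seen.

    lru_cache remembers each distinct message, so repeated calls with the
    same text return immediately without logging again.
    """
    LOG.warning(message)


DEPRECATION = (
    "Setting `pretraining_dataset` without explicitly setting "
    "`streaming: true` is deprecated."
)

# Called three times, logged once.
for _ in range(3):
    warn_once(DEPRECATION)
```

Keying the cache on the message text keeps the helper stateless from the caller's point of view. The trade-off is that messages containing run-specific values (for example, a step counter) would each count as new.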
1112-1141: Validation rules are correct; tighten messages to reflect “non-zero” and align with streaming semantics.
- The current condition treats 0 as allowed (good), but the error text reads like “not supported at all”. Clarify “non-zero val_set_size”.
- The “max_steps must be set when using streaming datasets” rule is right. Consider hinting that num_train_epochs is insufficient for iterable datasets.
Apply these small message tweaks:
-            raise ValueError(
-                "val_set_size is not supported with pretraining_dataset. "
-                "Use test_datasets to specify evaluation datasets for pretraining."
-            )
+            raise ValueError(
+                "Non-zero val_set_size is not supported with pretraining_dataset. "
+                "Use test_datasets to specify evaluation datasets for pretraining."
+            )
...
-            raise ValueError(
-                "val_set_size is not supported with streaming datasets. "
-                "Use test_datasets to specify evaluation datasets when streaming is enabled."
-            )
+            raise ValueError(
+                "Non-zero val_set_size is not supported with streaming datasets. "
+                "Use test_datasets to specify evaluation datasets when streaming is enabled."
+            )
...
-            raise ValueError(
-                "max_steps must be set when using streaming datasets. "
-                "Trainer cannot infer dataset length for iterable datasets."
-            )
+            raise ValueError(
+                "max_steps must be set when using streaming datasets. "
+                "Trainer cannot infer dataset length for iterable datasets; "
+                "num_train_epochs is insufficient."
+            )
If helpful, I can add unit tests asserting:
- val_set_size=None and =0 both pass with streaming/pretraining, but >0 raises.
- streaming_multipack_buffer_size=0 is preserved (no unintended migration).
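The rules discussed in this comment can be summarized in a short sketch. It is a simplification made under assumptions, not the validator in `validation.py`. The function `check_streaming_config` is hypothetical, it takes a plain mapping instead of the real config model, and it treats `pretraining_dataset` as implying streaming for both checks.

```python
from typing import Any, Mapping


def check_streaming_config(cfg: Mapping[str, Any]) -> None:
    """Raise ValueError if a streaming config uses unsupported options."""
    streaming = bool(cfg.get("streaming")) or bool(cfg.get("pretraining_dataset"))
    if not streaming:
        return  # the rules below only apply to iterable datasets

    # None and 0 are both falsy, so only a non-zero split size is rejected.
    if cfg.get("val_set_size"):
        raise ValueError(
            "Non-zero val_set_size is not supported with streaming datasets. "
            "Use test_datasets to specify evaluation datasets."
        )

    # An iterable dataset has no length, so the step budget must be explicit.
    if not cfg.get("max_steps"):
        raise ValueError(
            "max_steps must be set when using streaming datasets; "
            "num_train_epochs is insufficient."
        )


# Passes: streaming with an explicit step budget and no validation split.
check_streaming_config({"streaming": True, "max_steps": 1000, "val_set_size": 0})
```

A config such as `{"streaming": True, "max_steps": 1000, "val_set_size": 0.05}` would raise under this sketch, which matches the "non-zero" wording proposed for the error messages.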
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
src/axolotl/utils/collators/__init__.py (1 hunks)
src/axolotl/utils/schemas/validation.py (3 hunks)
tests/test_streaming.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/test_streaming.py
🧰 Additional context used
🧬 Code graph analysis (1)
src/axolotl/utils/collators/__init__.py (2)
src/axolotl/utils/collators/batching.py (4)
BatchSamplerDataCollatorForSeq2Seq(129-155)
DataCollatorForSeq2Seq(12-125)
PretrainingBatchSamplerDataCollatorForSeq2Seq(200-227)
V2BatchSamplerDataCollatorForSeq2Seq(159-196)
src/axolotl/utils/collators/mamba.py (1)
MambaDataCollator(15-39)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: preview
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
🔇 Additional comments (2)
src/axolotl/utils/collators/__init__.py (2)
3-9: Re-export imports look good and align with the packaging intent
The curated import surface is clean and avoids circulars. No issues spotted.
11-17: No star-imports detected; add `list[str]` annotation to `__all__`
All consumer code imports the exact names you’ve listed (no `import *` usages were found), so defining `__all__` won’t shrink any existing API surface. To satisfy strict type-checkers, annotate `__all__` with a type hint:
--- a/src/axolotl/utils/collators/__init__.py
+++ b/src/axolotl/utils/collators/__init__.py
@@ lines 11-17
-__all__ = [
+__all__: list[str] = [
     "DataCollatorForSeq2Seq",
     "BatchSamplerDataCollatorForSeq2Seq",
     "V2BatchSamplerDataCollatorForSeq2Seq",
     "PretrainingBatchSamplerDataCollatorForSeq2Seq",
     "MambaDataCollator",
 ]
With this change, the public API remains intact and your code is fully type-annotated.
NanoCode012
left a comment
Nice work refactoring out preprocess_iterable.
68f1a5c to 528070b
Description
Streaming datasets should not be limited to pretraining only. This PR changes that.
Also removed the `--iterable` preprocess CLI arg since it was confusing / maybe not advisable to do.
This is a follow-up to #3087 which is smaller in scope since things were broken / getting complicated in that PR.
TODO:
- `preprocess --iterable` deprecation (from Better support for streaming datasets; multidataset weighting / round robin #3087)
- `sample_packing` / `pretrain_multipack_attn` oddness
Follow-ups:
- `pretraining_dataset` config
- `--iterable` CLI arg
Motivation and Context
Some folks have been asking for better dataset streaming support. This is part of a larger effort to refactor / improve the logic around how we handle datasets / dataloading.
How has this been tested?
Summary by CodeRabbit
New Features
Documentation
Bug Fixes
Chores
Tests