Better support for streaming datasets; multidataset weighting / round robin#3087

Closed
djsaunde wants to merge 23 commits into main from streaming

Conversation

@djsaunde

@djsaunde djsaunde commented Aug 20, 2025

Collaborator

Description

Streaming datasets should not be limited to pretraining only. This PR lifts that restriction and adds support for different sampling strategies (round robin and sampling according to weights). It also disallows preprocessing of iterable datasets.
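To make the two sampling strategies concrete, here is a minimal standard-library sketch over plain Python iterables. It is an illustration only: the review below indicates the PR itself relies on `interleave_datasets`, and the helper names `round_robin` and `weighted`, as well as the choice to stop as soon as any stream is exhausted, are assumptions made for this example.

```python
import itertools
import random
from typing import Iterable, Iterator, Sequence, TypeVar

T = TypeVar("T")


def round_robin(streams: Sequence[Iterable[T]]) -> Iterator[T]:
    """Yield one item per stream in turn; stop once any stream runs out."""
    iterators = [iter(stream) for stream in streams]
    if not iterators:
        return
    while True:
        for iterator in iterators:
            try:
                yield next(iterator)
            except StopIteration:
                return


def weighted(
    streams: Sequence[Iterable[T]], weights: Sequence[float], seed: int = 0
) -> Iterator[T]:
    """Pick the next source stream at random, proportionally to `weights`."""
    rng = random.Random(seed)  # seeded so the mix is reproducible
    iterators = [iter(stream) for stream in streams]
    indices = range(len(iterators))
    while True:
        (index,) = rng.choices(indices, weights=weights, k=1)
        try:
            yield next(iterators[index])
        except StopIteration:
            return


# Usage: two endless toy streams, sampled roughly 9:1.
sample = list(
    itertools.islice(
        weighted([itertools.repeat("a"), itertools.repeat("b")], [0.9, 0.1]), 1000
    )
)
```

Round robin gives every dataset the same share regardless of size, while the weighted variant lets a small, high-value dataset be over- or under-sampled relative to the others.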

Motivation and Context

Some folks have been asking for better dataset streaming support. This is part of a larger effort to refactor / improve the logic around how we handle datasets / dataloading.

How has this been tested?

  • A few tests (need more)
  • A few manual runs

Summary by CodeRabbit

  • New Features

    • Streaming dataset support with per-split handling, sharding, on-demand loading, and dataset mixing strategies (round_robin, weighted, random, concatenate).
  • Bug Fixes / Validation

    • Stronger config validation for streaming, max_steps/num_epochs conflicts, and mixing-weights consistency.
  • Documentation

    • Enhanced schema descriptions for epochs, streaming, and mixing options.
  • Refactor

    • Removed legacy dataset-packing/streaming preprocess path and simplified preprocessing flow.
  • Tests

    • Added unit and end-to-end tests covering streaming and dataset-mixing behaviors.
  • Chores

    • Added config fields: streaming, dataset_mixing_strategy, mixing_weights; removed preprocess_iterable.

@djsaunde djsaunde requested a review from winglian August 20, 2025 06:09
@djsaunde djsaunde self-assigned this Aug 20, 2025
@coderabbitai

coderabbitai Bot commented Aug 20, 2025

Contributor

Warning

Rate limit exceeded

@djsaunde has exceeded the limit for the number of commits or files that can be reviewed per hour. Please wait 8 minutes and 23 seconds before requesting another review.

📥 Commits

Reviewing files that changed from the base of the PR and between 603c79b and 78a039e.

📒 Files selected for processing (15)
  • src/axolotl/cli/args.py (1 hunks)
  • src/axolotl/cli/preprocess.py (1 hunks)
  • src/axolotl/common/datasets.py (0 hunks)
  • src/axolotl/datasets.py (4 hunks)
  • src/axolotl/utils/data/sft.py (9 hunks)
  • src/axolotl/utils/data/shared.py (4 hunks)
  • src/axolotl/utils/data/utils.py (1 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • src/axolotl/utils/schemas/training.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (6 hunks)
  • src/axolotl/utils/trainer.py (4 hunks)
  • tests/e2e/integrations/test_kd.py (1 hunks)
  • tests/e2e/test_streaming.py (1 hunks)
  • tests/test_datasets.py (4 hunks)
  • tests/test_packed_dataset.py (0 hunks)
📝 Walkthrough

Walkthrough

Removes the preprocess_iterable CLI/config flag and shifts to streaming-first dataset support across loading, preparation, merging, tokenization, validation, and trainer paths; adds dataset mixing strategies (concatenate, round_robin, weighted, random); updates tests and e2e flows for streaming.

Changes

Cohort / File(s) Summary
CLI flag removal
src/axolotl/cli/args.py
Removes PreprocessCliArgs.iterable.
Config & validation
src/axolotl/utils/schemas/config.py, src/axolotl/utils/schemas/training.py, src/axolotl/utils/schemas/validation.py
Adds streaming, dataset_mixing_strategy, mixing_weights; removes preprocess_iterable; adds streaming-related validators and training/pretraining checks; documents num_epochs; ValidationMixin now includes StreamingValidationMixin.
Data loading / streaming refactor
src/axolotl/utils/data/sft.py, src/axolotl/common/datasets.py
Removes preprocess_iterable from signatures/calls; adds _is_streaming_enabled; streaming-aware flows returning `Dataset
Dataset merging strategies
src/axolotl/utils/data/shared.py
merge_datasets accepts `Dataset
Tokenization & wrappers
src/axolotl/datasets.py
TokenizedPromptDataset.process now typed as `Dataset
Long-sequence guards
src/axolotl/utils/data/utils.py
Short-circuits drop_long_seq when dataset lacks input_ids or is streaming; logs and returns early.
Trainer adjustments
src/axolotl/utils/trainer.py
Adds _create_filtered_iterable_dataset for streaming filtering; uses full train_dataset when calculating steps (no remove_columns(["length"])); minor comment/import cleanup.
Schemas: signature updates
src/axolotl/utils/data/sft.py (signatures)
Many function/method signatures updated to drop preprocess_iterable and accept/return `Dataset
Tests: unit & e2e
tests/test_datasets.py, tests/test_packed_dataset.py, tests/e2e/test_streaming.py, tests/e2e/integrations/test_kd.py
Add streaming fixtures/tests for SFT and mixing strategies; remove ConstantLengthDataset packing test; add e2e streaming tests; tweak KD e2e config (chat_template change).

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • Feat: add devstral model support #2880 — Modifies TokenizedPromptDataset.process and tokenization/multiprocessing behavior; likely overlaps in tokenization/streaming handling.
  • Various fixes for VLMs #3063 — Changes TokenizedPromptDataset and wrap_dataset_for_tokenized_prompt around features/remove_columns and map kwargs; strongly related to these adjustments.

Suggested reviewers

  • winglian
  • NanoCode012
  • SalmanMohammadi

@djsaunde djsaunde marked this pull request as ready for review August 20, 2025 15:33
@djsaunde

Collaborator Author

Currently seeing an error when using eval streaming datasets, will need to debug

@github-actions

github-actions Bot commented Aug 20, 2025

Contributor

📖 Documentation Preview: https://68a89608cd414f80a189f3f0--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit a7edc77

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/axolotl/utils/data/shared.py (1)

576-612: Bug: cfg.get usage will break with Pydantic config; align default with config and validate weights.

  • In this module, cfg is typically a config object with attribute access; using cfg.get will raise if cfg is not Dict-like. Use attribute access or getattr.
  • Default strategy here is "concatenate", but config default is "round_robin". Align to avoid surprise behavior when the field is missing.
  • For "weighted", validate length and non-negativity and normalize if weights don't sum to 1.0.
-    strategy = cfg.get("dataset_mixing_strategy", "concatenate")
-    weights = cfg.get("mixing_weights", None)
+    strategy = getattr(cfg, "dataset_mixing_strategy", None) or "round_robin"
+    weights = getattr(cfg, "mixing_weights", None)
@@
-    if strategy == "weighted":
-        return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
+    if strategy == "weighted":
+        if not weights or len(weights) != len(datasets):
+            raise ValueError(
+                "mixing_weights must be provided and match the number of datasets when "
+                "dataset_mixing_strategy='weighted'"
+            )
+        if min(weights) < 0:
+            raise ValueError("mixing_weights must be non-negative")
+        total = sum(weights)
+        # Normalize if needed
+        norm_weights = [w / total for w in weights] if total > 0 else weights
+        return interleave_datasets(datasets, probabilities=norm_weights, seed=cfg.seed)

Add this import at the top of the file if you choose to use math.isclose for exactness instead of simple normalization:

import math

Alternatively, keep simple normalization as shown above to avoid extra imports.
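The weight checks suggested in this comment can be read as one small helper. The sketch below is a stand-alone illustration, not code from the PR: the function name `normalize_mixing_weights` is hypothetical, and unlike the diff above it rejects an all-zero weight list instead of passing it through unchanged.

```python
from typing import List, Optional, Sequence


def normalize_mixing_weights(
    weights: Optional[Sequence[float]], num_datasets: int
) -> List[float]:
    """Validate mixing weights and rescale them so they sum to 1.0."""
    if not weights or len(weights) != num_datasets:
        raise ValueError(
            "mixing_weights must be provided and match the number of datasets"
        )
    if min(weights) < 0:
        raise ValueError("mixing_weights must be non-negative")
    total = sum(weights)
    if total <= 0:
        # Assumption for this sketch: all-zero weights are treated as an error.
        raise ValueError("mixing_weights must contain at least one positive value")
    return [weight / total for weight in weights]


# Usage: relative weights 2:1:1 become probabilities.
probabilities = normalize_mixing_weights([2, 1, 1], num_datasets=3)
```

Normalizing at the merge site makes the function tolerant of relative weights, whereas the config validator discussed later in this review insists on weights that already sum to 1.0; only one of the two policies is needed.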

🧹 Nitpick comments (12)
src/axolotl/utils/schemas/training.py (1)

164-169: Good JSON-schema enrichment for num_epochs. Consider clarifying streaming interplay.

Since this PR introduces streaming, it helps to state that max_steps typically governs iteration when streaming is enabled.

-    num_epochs: float = Field(
-        default=1.0,
-        json_schema_extra={
-            "description": "Number of iterations over dataset for training"
-        },
-    )
+    num_epochs: float = Field(
+        default=1.0,
+        json_schema_extra={
+            "description": "Number of passes over the dataset for training. When using streaming datasets, training is typically governed by `max_steps` (recommended to set) and `num_epochs` may not correspond to a full pass over data."
+        },
+    )
src/axolotl/datasets.py (3)

49-85: Avoid large batched mapping on streaming to prevent memory spikes; also prefer column_names for remove_columns.

  • Setting batch_size for IterableDataset can accumulate large in-memory batches; better to only set batch_size for non-streaming datasets.
  • Using dataset.column_names (or list(features)) is a safer/clearer input for remove_columns than a keys view.
-        features = None
-        if not isinstance(dataset, IterableDataset):
-            features = dataset.features.keys()
-
-        map_kwargs: dict[str, Any] = {}
-        if self.prompt_tokenizer.supports_batched:
-            map_kwargs["batched"] = True
-            map_kwargs["batch_size"] = 1_000
+        features = None
+        map_kwargs: dict[str, Any] = {}
+        if not isinstance(dataset, IterableDataset):
+            # Prefer column_names; ensure it's a concrete list
+            features = list(dataset.column_names)
+            if self.prompt_tokenizer.supports_batched:
+                map_kwargs["batched"] = True
+                map_kwargs["batch_size"] = 1_000
+        else:
+            if self.prompt_tokenizer.supports_batched:
+                # For streaming, enable batched but omit batch_size to avoid large in-memory batches
+                map_kwargs["batched"] = True
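The branching in the suggested diff can be isolated into a pure function, which makes the intended behavior easy to unit test. This is a sketch under assumptions: `build_map_kwargs` and its parameters are hypothetical names, and the resulting dictionary is meant to be passed as `dataset.map(fn, **map_kwargs)`.

```python
from typing import Any, Dict


def build_map_kwargs(
    is_streaming: bool, supports_batched: bool, batch_size: int = 1_000
) -> Dict[str, Any]:
    """Keyword arguments for `dataset.map`, following the suggestion above."""
    map_kwargs: Dict[str, Any] = {}
    if supports_batched:
        map_kwargs["batched"] = True
        if not is_streaming:
            # Only in-memory datasets get an explicit batch size.
            map_kwargs["batch_size"] = batch_size
    return map_kwargs


# Usage: a streaming dataset with a batched tokenizer omits batch_size.
streaming_kwargs = build_map_kwargs(is_streaming=True, supports_batched=True)
```

One caveat worth checking before adopting this: as far as I know, `map` in Hugging Face `datasets` applies its own default batch size when `batch_size` is omitted, so leaving it out may not actually shrink the batches. Passing a smaller explicit value for streaming datasets may be the more reliable way to bound memory.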

80-85: remove_columns wiring is correct; minor type-safety nit.

Since features can be a keys view in some contexts, ensure it's a list (covered by the previous suggestion). Otherwise, this looks good.


108-110: If ConstantLengthDataset is test-only, move it under tests to reduce public surface.

If nothing in src imports this except tests, consider relocating it to tests and importing locally in the test to avoid shipping unused runtime code.

Would you like me to draft the test-side helper and remove this class from the package code?

src/axolotl/utils/schemas/config.py (1)

935-970: New streaming and mixing config fields are well-scoped; consider stricter typing and validation hints.

  • dataset_mixing_strategy and eval_dataset_mixing_strategy could benefit from Literal types for schema/validation.
  • Ensure validations enforce:
    • mixing_weights length equals number of datasets, non-negative, and sum to 1.0 (or normalized consistently).
    • When streaming=True (or eval_streaming=True), max_steps is provided (as documented).
-    dataset_mixing_strategy: str | None = Field(
+    dataset_mixing_strategy: Literal["concatenate", "round_robin", "weighted", "random"] | None = Field(
         default="round_robin",
         json_schema_extra={
             "description": "Strategy for mixing multiple datasets: 'concatenate', 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets."
         },
     )
@@
-    eval_dataset_mixing_strategy: str | None = Field(
+    eval_dataset_mixing_strategy: Literal["concatenate", "round_robin", "weighted", "random"] | None = Field(
         default=None,
         json_schema_extra={
             "description": "Strategy for mixing multiple evaluation datasets. If not set, falls back to dataset_mixing_strategy. Options: 'concatenate', 'round_robin', 'weighted', 'random'."
         },
     )

If validations are already added in utils/schemas/validation.py, ignore this and consider marking them in the docstrings for discoverability.
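To show what the `Literal` annotation buys, here is a dependency-free sketch that uses only `typing`. It is not the Pydantic model from the PR: `MixingStrategy` and `validate_strategy` are hypothetical names, and in the real schema Pydantic would perform the equivalent check from the annotation alone.

```python
from typing import Literal, Optional, get_args

# The four strategies named in the field description above.
MixingStrategy = Literal["concatenate", "round_robin", "weighted", "random"]


def validate_strategy(value: Optional[str]) -> Optional[str]:
    """Return `value` if it is None or one of the allowed strategy names."""
    allowed = get_args(MixingStrategy)
    if value is not None and value not in allowed:
        raise ValueError(
            f"dataset_mixing_strategy must be one of {allowed}, got {value!r}"
        )
    return value


# Usage
strategy = validate_strategy("weighted")
```

Keeping the allowed values in a single `Literal` alias means the type checker, the runtime check, and the generated JSON schema cannot drift apart.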

src/axolotl/utils/schemas/validation.py (6)

519-531: Defaulting num_epochs to 1 is good; prefer int over float

Setting a default is helpful. Consider using an int to avoid subtle type inconsistencies downstream.

-        if max_steps is None and num_epochs is None:
-            data["num_epochs"] = 1.0
+        if max_steps is None and num_epochs is None:
+            data["num_epochs"] = 1

532-548: Combine nested conditions (ruff SIM102)

Simplify the nested if to a single condition for clarity.

-        if saves_per_epoch is not None:
-            # Check if saves_per_epoch is set but num_epochs is unset
-            if num_epochs is None:
-                raise ValueError(
-                    "saves_per_epoch requires num_epochs to be set to calculate save "
-                    "intervals."
-                )
+        if saves_per_epoch is not None and num_epochs is None:
+            raise ValueError(
+                "saves_per_epoch requires num_epochs to be set to calculate save intervals."
+            )

549-564: Combine nested conditions (ruff SIM102)

Mirror the change above for evals_per_epoch.

-        if evals_per_epoch is not None:
-            if num_epochs is None:
-                raise ValueError(
-                    "evals_per_epoch requires num_epochs to be set to calculate "
-                    "evaluation intervals."
-                )
+        if evals_per_epoch is not None and num_epochs is None:
+            raise ValueError(
+                "evals_per_epoch requires num_epochs to be set to calculate evaluation intervals."
+            )

1420-1431: Error message says “skipping” but you raise—reword to avoid confusion

You raise a ValueError, so “skipping” is misleading.

-                raise ValueError(
-                    "Validation splits not supported for streaming datasets, skipping"
-                )
+                raise ValueError(
+                    "Validation splits (val_set_size > 0) are not supported with streaming datasets."
+                )

1433-1442: Optional: avoid redundant checks (ruff SIM102-like)

Compute the “streaming enabled” condition once to avoid double-calling.

-        if self._is_streaming_enabled("train") or self._is_streaming_enabled("eval"):
-            if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
+        streaming_enabled = self._is_streaming_enabled("train") or self._is_streaming_enabled("eval")
+        if streaming_enabled and os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
                 raise ValueError("preprocess is not supported for streaming datasets")

1503-1548: Optional: use math.isclose for summation tolerance

Using math.isclose improves readability for the 1.0 sum check.

-            if abs(sum(weights) - 1.0) > 1e-6:
-                raise ValueError(f"{weights_field} must sum to 1.0, got {sum(weights)}")
+            import math
+            total = sum(weights)
+            if not math.isclose(total, 1.0, rel_tol=0.0, abs_tol=1e-6):
+                raise ValueError(f"{weights_field} must sum to 1.0, got {total}")
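A short example of why a tolerance is needed at all: ten weights of 0.1 do not sum to exactly 1.0 in binary floating point. The helper name `weights_sum_to_one` is made up for this sketch.

```python
import math
from typing import Sequence


def weights_sum_to_one(weights: Sequence[float], abs_tol: float = 1e-6) -> bool:
    """True if the weights sum to 1.0 within an absolute tolerance."""
    return math.isclose(sum(weights), 1.0, rel_tol=0.0, abs_tol=abs_tol)


ten_tenths = [0.1] * 10
exact_match = sum(ten_tenths) == 1.0          # False: rounding error accumulates
tolerant_match = weights_sum_to_one(ten_tenths)  # True
```

With `rel_tol=0.0` the check behaves essentially like the original `abs(sum(weights) - 1.0) > 1e-6` comparison, so the change is about readability rather than behavior.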
tests/test_datasets.py (1)

637-667: Strengthen assertion to verify both sources appear

len(set(sources)) >= 1 is always true. Check for both sources to ensure interleave correctness.

-        sources = [sample["source"] for sample in samples]
-        assert len(set(sources)) >= 1  # At least one unique source
+        sources = [sample["source"] for sample in samples]
+        # Expect samples from both streams
+        assert len(set(sources)) >= 2
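The difference between the two assertions is easy to demonstrate with toy samples. The source names `dataset_a` and `dataset_b` and the helper `distinct_sources` are placeholders for this sketch, not values from the test file.

```python
from typing import Dict, List, Set


def distinct_sources(samples: List[Dict[str, str]]) -> Set[str]:
    """Collect the set of `source` values seen in a list of samples."""
    return {sample["source"] for sample in samples}


# A broken interleave that only ever reads the first dataset...
only_first = [{"source": "dataset_a"}] * 4
# ...and a working one that draws from both.
mixed = [{"source": "dataset_a"}, {"source": "dataset_b"}] * 2

weak_check_on_broken = len(distinct_sources(only_first)) >= 1    # passes anyway
strict_check_on_broken = len(distinct_sources(only_first)) >= 2  # fails, as desired
strict_check_on_mixed = len(distinct_sources(mixed)) >= 2
```

Any non-empty sample list satisfies the `>= 1` form, so only the `>= 2` form can catch a merge that silently drops a dataset.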
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 06eaf6c and 49e528f.

📒 Files selected for processing (12)
  • src/axolotl/cli/args.py (0 hunks)
  • src/axolotl/common/datasets.py (0 hunks)
  • src/axolotl/datasets.py (4 hunks)
  • src/axolotl/utils/data/sft.py (8 hunks)
  • src/axolotl/utils/data/shared.py (4 hunks)
  • src/axolotl/utils/data/utils.py (1 hunks)
  • src/axolotl/utils/data/wrappers.py (0 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • src/axolotl/utils/schemas/training.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (5 hunks)
  • src/axolotl/utils/trainer.py (1 hunks)
  • tests/test_datasets.py (4 hunks)
💤 Files with no reviewable changes (3)
  • src/axolotl/common/datasets.py
  • src/axolotl/utils/data/wrappers.py
  • src/axolotl/cli/args.py
🧰 Additional context used
🧬 Code Graph Analysis (3)
tests/test_datasets.py (3)
src/axolotl/utils/data/sft.py (4)
  • _load_tokenized_prepared_datasets (292-351)
  • prepare_datasets (90-107)
  • _is_streaming_enabled_for_split (47-66)
  • _get_streaming_config_for_split (69-86)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/utils/data/shared.py (1)
  • _merge_datasets_with_strategy (576-611)
src/axolotl/utils/data/shared.py (1)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/utils/data/sft.py (4)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/utils/data/utils.py (2)
  • retry_on_request_exceptions (31-70)
  • deduplicate_and_log_datasets (112-148)
src/axolotl/utils/trainer.py (1)
  • calculate_total_num_steps (392-519)
src/axolotl/utils/data/shared.py (4)
  • generate_dataset_hash_from_config (506-525)
  • try_load_from_hub (486-503)
  • load_preprocessed_dataset (455-483)
  • save_preprocessed_dataset (407-452)
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/validation.py

539-541: Use a single if statement instead of nested if statements

(SIM102)


556-557: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


1437-1438: Use a single if statement instead of nested if statements

(SIM102)

🔇 Additional comments (31)
src/axolotl/utils/trainer.py (1)

550-553: Typo fix acknowledged; comment is now accurate.

The note clarifies the correct timing of distributed state init relative to ACCELERATE_USE_DEEPSPEED. No functional changes introduced.

src/axolotl/datasets.py (1)

1-11: Docstring and typing import look good.

Clearer module overview and Any import are appropriate for the added map/filter kwargs.

src/axolotl/utils/data/shared.py (2)

551-556: Pre-merge per-dataset shuffle is guarded correctly.

Only shuffles when all inputs are regular Datasets, avoiding IterableDataset pitfalls. Good.


557-573: Post-merge shuffle logic is sensible; good warning for curriculum_sampling.

Skips shuffle for streaming datasets and warns when curriculum sampling is on. Works as expected.

src/axolotl/utils/schemas/validation.py (8)

6-6: Import looks correct

Importing os is required for the new env-var checks below.


196-196: Pylint suppression is reasonable

Given the aggregation of many validators in this mixin, suppressing too-many-public-methods here is acceptable.


513-515: Nit: tighten the error message

The error message says “cannot be used together.” That’s fine, just noting the change is cosmetic. No action needed.


1388-1410: Nice addition: central streaming enablement logic

The per-context _is_streaming_enabled helper provides a clear source of truth and will reduce drift across call sites.


1411-1419: Requiring max_steps with streaming is correct

Trainer can’t infer length on streams; this validator is essential.


1444-1456: Auto-enforcing skip_prepare_dataset=True under streaming is good

The warning for explicit False is helpful; the auto-flip prevents expensive/invalid ops.


1458-1501: Mixing strategy/weights validation is solid

Covers invalid strategies, non-negative numeric weights, sum-to-1, and length matching. Good.


1561-1561: Good to include StreamingValidationMixin in ValidationMixin

Ensures streaming validations run by default across config usage.

src/axolotl/utils/data/sft.py (12)

12-12: Import IterableDatasetDict is necessary

You correctly handle both DatasetDict and IterableDatasetDict downstream.


47-67: Per-split streaming toggle looks correct

The precedence of eval_streaming (for test split), then streaming, then pretraining-default is sensible.
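The precedence described here can be sketched as follows. This is a reading of the comment, not the PR's implementation: the config is modeled as a plain mapping, and "pretraining-default" is assumed to mean that streaming is on when a `pretraining_dataset` entry is configured.

```python
from typing import Any, Mapping


def is_streaming_enabled_for_split(cfg: Mapping[str, Any], split: str) -> bool:
    """Resolve streaming for a split: eval override, then global flag, then default."""
    # 1. An explicit eval_streaming setting wins, but only for the test split.
    if split == "test" and cfg.get("eval_streaming") is not None:
        return bool(cfg["eval_streaming"])
    # 2. Otherwise an explicit global streaming setting applies.
    if cfg.get("streaming") is not None:
        return bool(cfg["streaming"])
    # 3. Assumed default: stream only when a pretraining dataset is configured.
    return bool(cfg.get("pretraining_dataset"))


# Usage: stream the training data but load the eval data eagerly.
cfg = {"streaming": True, "eval_streaming": False}
train_streams = is_streaming_enabled_for_split(cfg, "train")
eval_streams = is_streaming_enabled_for_split(cfg, "test")
```

Checking for `None` instead of truthiness is what lets `eval_streaming: false` override `streaming: true`.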


69-87: Eval-specific mixing overrides are handled cleanly

Using a shallow DictDefault copy to override only mixing fields is appropriate here.


106-108: Early branch to pretraining path is fine

Removes the old preprocess_iterable surface and simplifies prepare_datasets.


156-166: Correct: derive total steps from max_steps for streaming train datasets

Relying on cfg.max_steps for IterableDataset avoids unreliable step estimates. This assumes validators ensure max_steps is present when streaming.

If we want to harden this path for unvalidated cfg (e.g., in unit tests), we could add a defensive check:

  • If train_dataset is IterableDataset and cfg.max_steps is falsy, raise a descriptive error early.
    Do you want a patch for that?
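A defensive check along those lines might look like the sketch below. The function name and parameters are hypothetical, and the real code derives step counts inside `prepare_datasets` and `calculate_total_num_steps` with more inputs than shown here.

```python
from typing import Optional


def resolve_total_steps(
    dataset_is_iterable: bool,
    max_steps: Optional[int],
    steps_from_length: Optional[int] = None,
) -> int:
    """Pick the total step count, failing early for streams without max_steps."""
    if dataset_is_iterable:
        if not max_steps:
            raise ValueError(
                "max_steps must be set to a positive value when the training "
                "dataset is streamed, because its length is unknown"
            )
        return max_steps
    if steps_from_length is None:
        raise ValueError("steps_from_length is required for non-streaming datasets")
    return steps_from_length


# Usage
streaming_steps = resolve_total_steps(dataset_is_iterable=True, max_steps=100)
```

Raising here gives unit tests that bypass the schema validators a descriptive error instead of a failure deeper in the trainer.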

310-351: Streaming vs. cached loading paths are well separated

The streaming path bypasses caching and relies on raw loader; non-streaming still benefits from hub/disk caching. Good design.


395-401: Skip persisting streaming datasets

Guarding save_preprocessed_dataset to only run for non-iterable datasets prevents invalid save attempts. Good catch.


420-420: Split selection supports IterableDatasetDict

Good addition to support dict-based iterable datasets.


465-472: Val split size parsing is clear

Covers both absolute (>1) and fractional (0-1] expressions.


482-487: Dedup correctly skipped for streaming datasets

This matches utils.handle_long_seq_in_dataset behavior for streams.


493-501: Consistent dedup behavior for eval split

Mirrors the train handling. Looks good.


506-527: Shard application works for both Dataset and IterableDataset

Good reuse via a shared helper, and logging shard selection is helpful.

tests/test_datasets.py (7)

27-28: OK to suppress too-many-public-methods for this test class

Keeps pylint quiet without hiding real issues.


50-67: Streaming fixture looks fine

Simple generator-based IterableDataset is appropriate for tests.


509-550: Good end-to-end streaming test

Covers type, step count, and basic sample structure. One note: since tests bypass schema validators, they rely on utils.handle_long_seq_in_dataset to gracefully no-op for streams (fixed above).


551-571: Strategy validation smoke test is reasonable

Exercising _merge_datasets_with_strategy across valid strategies is useful.


572-597: Round-robin mixing test looks good

Interleaving assertion is sensible.


597-637: Weighted mixing test is fine

Lightweight proportionality and representation checks make sense here.


684-706: Eval mixing overrides test is accurate

Confirms that eval-specific settings shadow the main ones for the test split.

Comment thread src/axolotl/utils/data/sft.py Outdated
Comment thread src/axolotl/utils/data/utils.py
Comment thread tests/test_datasets.py Outdated

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
src/axolotl/utils/data/sft.py (2)

210-236: Potential KeyError extracting pretraining config

Accessing config["name"], config["skip"], config["data_files"] will KeyError when omitted. Use .get with defaults.

Apply this diff:

         return DictDefault(
             {
                 "path": config["path"],
-                "name": config["name"],
-                "skip": config["skip"],
+                "name": config.get("name"),
+                "skip": config.get("skip", 0),
                 "split": config.get("split", "train"),
-                "data_files": config.get("data_files"),
+                "data_files": config.get("data_files"),
                 "type": config.get("type", "pretrain"),
             }
         )
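The effect of the suggested change is easiest to see on a minimal config that sets only `path`. The sketch uses a plain `dict` in place of `DictDefault`, and the function name and dataset path are placeholders.

```python
from typing import Any, Dict


def extract_pretraining_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Copy the known pretraining keys, filling defaults for optional ones."""
    return {
        "path": config["path"],  # required: a missing path should still fail loudly
        "name": config.get("name"),
        "skip": config.get("skip", 0),
        "split": config.get("split", "train"),
        "data_files": config.get("data_files"),
        "type": config.get("type", "pretrain"),
    }


# Usage: only `path` is given; every other key falls back to its default.
resolved = extract_pretraining_config({"path": "some-org/some-dataset"})
```

With the original `config["name"]` and `config["skip"]` lookups, this same input would raise `KeyError` on a plain dict.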

252-267: Be defensive accessing dispatch_batches from accelerator_config

accelerator_config can be a dict or object; attribute access may fail if the key is absent. Use safe access to avoid AttributeError.

Apply this diff:

-    if (
-        cfg.accelerator_config
-        and cfg.accelerator_config.dispatch_batches
-        and not is_local_main_process()
-    ):
+    accel = cfg.accelerator_config
+    dispatch_batches = (
+        accel.get("dispatch_batches")
+        if isinstance(accel, dict)
+        else getattr(accel, "dispatch_batches", None)
+    ) if accel else None
+    if accel and dispatch_batches and not is_local_main_process():
         iter_dataset = _create_placeholder_dataset()
     else:
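The inline conditional in the diff above is dense, so here is the same lookup as a small helper that is easier to read and test. `get_dispatch_batches` is a hypothetical name, and `SimpleNamespace` stands in for whatever object type `accelerator_config` may be.

```python
from types import SimpleNamespace
from typing import Any, Optional


def get_dispatch_batches(accelerator_config: Any) -> Optional[bool]:
    """Read dispatch_batches from a dict-like or attribute-style config."""
    if not accelerator_config:
        return None
    if isinstance(accelerator_config, dict):
        return accelerator_config.get("dispatch_batches")
    return getattr(accelerator_config, "dispatch_batches", None)


# Usage: both config styles resolve to the same value.
from_dict = get_dispatch_batches({"dispatch_batches": True})
from_object = get_dispatch_batches(SimpleNamespace(dispatch_batches=True))
```

The caller can then write `if get_dispatch_batches(cfg.accelerator_config) and not is_local_main_process():` without caring which form the config takes.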
♻️ Duplicate comments (1)
src/axolotl/utils/data/sft.py (1)

415-423: Per-split streaming flag now passed to loader (addresses earlier review)

This fixes the previously reported bug where cfg.streaming was used indiscriminately for all splits. Thanks for addressing.

🧹 Nitpick comments (13)
src/axolotl/utils/schemas/validation.py (6)

519-531: Sensible default: auto-set num_epochs=1.0 when steps/epochs unset

Good call to set a sane default; it unblocks downstream logic in normalize_config that derives save/eval steps. Consider documenting this behavior in the user-facing config docs.


532-548: Flatten nested condition for clarity (SIM102)

You can simplify the nested if to a single condition.

Apply this diff:

-        if saves_per_epoch is not None:
-            # Check if saves_per_epoch is set but num_epochs is unset
-            if num_epochs is None:
-                raise ValueError(
-                    "saves_per_epoch requires num_epochs to be set to calculate save "
-                    "intervals."
-                )
+        # Check if saves_per_epoch is set but num_epochs is unset
+        if saves_per_epoch is not None and num_epochs is None:
+            raise ValueError(
+                "saves_per_epoch requires num_epochs to be set to calculate save "
+                "intervals."
+            )

549-564: Flatten nested condition for clarity (SIM102)

Same idea here; combine into a single if.

Apply this diff:

-        if evals_per_epoch is not None:
-            if num_epochs is None:
-                raise ValueError(
-                    "evals_per_epoch requires num_epochs to be set to calculate "
-                    "evaluation intervals."
-                )
+        if evals_per_epoch is not None and num_epochs is None:
+            raise ValueError(
+                "evals_per_epoch requires num_epochs to be set to calculate "
+                "evaluation intervals."
+            )

1445-1455: Error message says “skipping” but raises — adjust message

The exception text implies a soft skip; it actually hard-fails. Reword to avoid confusion.

Apply this diff:

-                raise ValueError(
-                    "Validation splits not supported for streaming datasets, skipping"
-                )
+                raise ValueError(
+                    "Validation splits are not supported with streaming datasets. Please set `val_set_size: 0` or use `test_datasets`."
+                )

1457-1465: Preprocess guard via env var is fine, but consider cfg flag too

This correctly blocks preprocess runs with streaming. Optional future enhancement: also respect a config-level is_preprocess (if present) to avoid relying solely on environment.


1558-1560: Use isclose for weight sum tolerance

Floating point sums can be finicky. Use math.isclose for robustness.

Apply this diff:

-            if abs(sum(weights) - 1.0) > 1e-6:
-                raise ValueError(f"{weights_field} must sum to 1.0, got {sum(weights)}")
+            import math
+            if not math.isclose(sum(weights), 1.0, rel_tol=1e-9, abs_tol=1e-6):
+                raise ValueError(f"{weights_field} must sum to 1.0, got {sum(weights)}")
tests/e2e/test_streaming.py (3)

29-65: E2E config likely requires GPU; add skip-if-no-GPU to stabilize CI

flash_attention=True and optimizer="adamw_torch_fused" typically require a CUDA environment and supported libs. To avoid sporadic CI failures on CPU jobs, gate these tests on GPU availability.

Apply this diff to add torch import and a class-level skip:

@@
-import pytest
+import pytest
+import torch
@@
-class TestStreamingDatasets:
+@pytest.mark.skipif(
+    not torch.cuda.is_available(),
+    reason="GPU required for flash-attn and fused optimizer in streaming E2E tests",
+)
+class TestStreamingDatasets:

Alternatively, drop flash_attention/fused optimizer in these smoke tests to allow CPU runs.


78-85: TensorBoard threshold is quite tight for smoke tests

A fixed <2.5 loss can be flaky across seeds/datasets. Consider relaxing or asserting presence of metrics rather than absolute value to reduce flakiness.


139-141: Eval cadence: eval_steps=3 with max_steps=3

This triggers eval only at end; OK. If you want to see intermediate eval as well, set eval_steps=1. Not a blocker.

src/axolotl/utils/data/sft.py (4)

69-86: Shallow-copied cfg for eval overrides — acceptable, but note mutability

DictDefault(cfg) creates a shallow copy; only top-level overrides are adjusted, which is fine for strategy/weights. If deeper eval-only overrides are added later, consider copy.deepcopy to avoid aliasing surprises.
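
To illustrate the aliasing concern, here is a small sketch that uses a plain `dict` as a stand-in for `DictDefault`. That `DictDefault(cfg)` shares nested objects in the same way is an assumption taken from the comment above, not something verified here.

```python
import copy

cfg = {"dataset_mixing_strategy": "round_robin", "datasets": [{"path": "a"}]}

shallow = dict(cfg)        # top-level keys are copied, nested objects are shared
deep = copy.deepcopy(cfg)  # nested objects are copied as well

# A top-level override on the copy leaves the original untouched.
shallow["dataset_mixing_strategy"] = "weighted"

# A nested mutation through the shallow copy leaks into the original.
shallow["datasets"][0]["path"] = "b"
```

After these two assignments the original still reports `round_robin`, but its nested dataset path has changed, while the deep copy keeps the old path.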


148-155: Fix typo in comment

Minor typo in the comment makes it hard to read.

Apply this diff:

-    # Skip validation for streaming eval datasets since theWhat hy don't have a calculable length
+    # Skip validation for streaming eval datasets since they don't have a calculable length

298-304: Return type should include IterableDataset/IterableDatasetDict

When streaming, this function returns IterableDataset or IterableDatasetDict; reflect that in the type hint.

Apply this diff:

-) -> tuple[Dataset | DatasetDict, list[Prompter | None]]:
+) -> tuple[Dataset | IterableDataset | DatasetDict | IterableDatasetDict, list[Prompter | None]]:

360-367: Return type should include IterableDataset

The streaming case returns IterableDataset. Update the annotation accordingly.

Apply this diff:

-) -> tuple[Dataset, list[Prompter | None]]:
+) -> tuple[Dataset | IterableDataset, list[Prompter | None]]:
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 49e528f and c03e310.

📒 Files selected for processing (3)
  • src/axolotl/utils/data/sft.py (9 hunks)
  • src/axolotl/utils/schemas/validation.py (6 hunks)
  • tests/e2e/test_streaming.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
tests/e2e/test_streaming.py (3)
src/axolotl/utils/config/__init__.py (2)
  • normalize_config (97-237)
  • validate_config (259-303)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
tests/e2e/utils.py (2)
  • check_model_output_exists (171-192)
  • check_tensorboard (149-168)
src/axolotl/utils/data/sft.py (4)
src/axolotl/utils/dict.py (1)
  • DictDefault (6-38)
src/axolotl/utils/data/utils.py (1)
  • deduplicate_and_log_datasets (112-148)
src/axolotl/utils/trainer.py (1)
  • calculate_total_num_steps (392-519)
src/axolotl/utils/data/shared.py (4)
  • try_load_from_hub (486-503)
  • load_preprocessed_dataset (455-483)
  • save_preprocessed_dataset (407-452)
  • load_dataset_with_config (94-149)
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/validation.py

539-541: Use a single if statement instead of nested if statements

(SIM102)


556-557: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


1461-1462: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: preview
🔇 Additional comments (14)
src/axolotl/utils/schemas/validation.py (5)

1130-1153: Accelerate split/dispatch defaults for streaming look correct

Ensuring split_batches/dispatch_batches=False when streaming (train or eval) is enabled avoids IterableDataset incompatibilities with batch splitting. Looks good as a defensive default.


1434-1443: Max-steps requirement for streaming is correct

Requiring max_steps when streaming prevents unknown-length dataset issues with schedulers/optimizers. Solid.


1471-1480: Force skip_prepare_dataset=True for streaming — good

This aligns behavior across train/eval streaming. The warning when user sets False is helpful.


1562-1566: Good: enforce weights length to match dataset count

This check will catch common misconfigs early (covered in tests/e2e). Nicely done.


1585-1586: Wiring StreamingValidationMixin into ValidationMixin — good integration

Exposing streaming validations in the public mixin ensures consistent config behavior across the app.

tests/e2e/test_streaming.py (2)

16-23: Param covers round_robin/weighted/random — verify runtime support for “random”

Validation accepts "random", but ensure the dataset merging logic actually implements it to avoid false positives in tests.

Use the script I provided in the validator comment to confirm "random" handling is implemented.


164-202: Validation error test is solid

This accurately exercises the new weights-length validator. Nice.

src/axolotl/utils/data/sft.py (7)

47-67: Split-aware streaming gate is correct and matches validator semantics

Per-split streaming with eval_streaming override is cleanly handled here. Good alignment with schema validations.


162-172: Compute total steps for IterableDataset via max_steps — correct

This matches streaming constraints (unknown length). Good guard.


318-331: Stream path uses eval overrides — nice separation

Passing the split-specific cfg into _load_raw_datasets ensures eval-only strategy/weights are respected.


401-407: Skip saving streaming datasets — correct

The isinstance(dataset, IterableDataset) guard avoids expensive/invalid save-to-disk for streaming. Good.


427-437: Handle IterableDatasetDict split selection — good

Accounting for both DatasetDict and IterableDatasetDict prevents runtime errors for hub datasets that materialize as iterable dicts under streaming.


471-497: Dedup skip for streaming explicitly logged — good DX

Clear log messaging avoids confusion when dedup is configured but not applicable to IterableDataset.


513-534: Sharding supports both Dataset and IterableDataset

Looks correct; both types implement shard in HF Datasets. Good reuse.

Comment thread src/axolotl/utils/schemas/validation.py
Comment thread src/axolotl/datasets.py Outdated

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
src/axolotl/utils/schemas/validation.py (1)

1483-1526: Comprehensive dataset mixing validation - resolves past review concerns.

The validation correctly:

  1. Validates all mixing strategies including "random" (which was missing in the previous commit)
  2. Enforces weight requirements for "weighted" strategy
  3. Validates weight constraints (non-negative, sum to 1.0)
  4. Matches weight count with dataset count
  5. Warns about ignored weights for non-weighted strategies

This addresses the concern raised in the past review about the missing "random" strategy implementation.

src/axolotl/utils/data/sft.py (1)

420-423: Critical: Use split-specific streaming configuration.

The current implementation correctly determines split-specific streaming but then passes the global cfg to load_dataset_with_config. This could cause issues with evaluation-specific mixing strategies.

Apply this fix to use the split-specific configuration:

 use_streaming_for_split = _is_streaming_enabled_for_split(cfg, split)
+streaming_cfg = _get_streaming_config_for_split(cfg, split) if use_streaming_for_split else cfg
 dataset = load_dataset_with_config(
-    dataset_config, cfg.hf_use_auth_token, use_streaming_for_split
+    dataset_config, streaming_cfg.hf_use_auth_token, use_streaming_for_split
 )

Also note that this aligns with the duplicate comment from the past review about using per-split streaming configuration.

🧹 Nitpick comments (2)
src/axolotl/utils/schemas/validation.py (2)

520-564: Consider simplifying nested conditions.

The nested if statements could be simplified for better readability. However, the logic is correct and the implementation properly handles conflicts between max_steps/num_epochs and epoch-based intervals.

Apply these simplifications:

 def check_saves_per_epoch_conflicts(cls, data):
     """Ensure saves_per_epoch is compatible with training configuration."""
     saves_per_epoch = data.get("saves_per_epoch")
     num_epochs = data.get("num_epochs")
 
-    if saves_per_epoch is not None:
-        # Check if saves_per_epoch is set but num_epochs is unset
-        if num_epochs is None:
-            raise ValueError(
-                "saves_per_epoch requires num_epochs to be set to calculate save "
-                "intervals."
-            )
+    if saves_per_epoch is not None and num_epochs is None:
+        raise ValueError(
+            "saves_per_epoch requires num_epochs to be set to calculate save "
+            "intervals."
+        )
 
     return data

 def check_evals_per_epoch_conflicts(cls, data):
     """Ensure evals_per_epoch is compatible with training configuration."""
     evals_per_epoch = data.get("evals_per_epoch")
     num_epochs = data.get("num_epochs")
 
-    if evals_per_epoch is not None:
-        if num_epochs is None:
-            raise ValueError(
-                "evals_per_epoch requires num_epochs to be set to calculate "
-                "evaluation intervals."
-            )
+    if evals_per_epoch is not None and num_epochs is None:
+        raise ValueError(
+            "evals_per_epoch requires num_epochs to be set to calculate "
+            "evaluation intervals."
+        )
 
     return data

1461-1462: Simplify nested condition check.

Combine the nested if statements for better readability.

-        if self._is_streaming_enabled("train") or self._is_streaming_enabled("eval"):
-            if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
-                raise ValueError("preprocess is not supported for streaming datasets")
+        if (
+            self._is_streaming_enabled("train") or self._is_streaming_enabled("eval")
+        ) and os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
+            raise ValueError("preprocess is not supported for streaming datasets")
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between c03e310 and 16f7c50.

📒 Files selected for processing (14)
  • src/axolotl/cli/args.py (0 hunks)
  • src/axolotl/common/datasets.py (0 hunks)
  • src/axolotl/datasets.py (3 hunks)
  • src/axolotl/utils/data/sft.py (9 hunks)
  • src/axolotl/utils/data/shared.py (4 hunks)
  • src/axolotl/utils/data/utils.py (1 hunks)
  • src/axolotl/utils/data/wrappers.py (0 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • src/axolotl/utils/schemas/training.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (6 hunks)
  • src/axolotl/utils/trainer.py (1 hunks)
  • tests/e2e/test_streaming.py (1 hunks)
  • tests/test_datasets.py (4 hunks)
  • tests/test_packed_dataset.py (0 hunks)
💤 Files with no reviewable changes (4)
  • src/axolotl/utils/data/wrappers.py
  • tests/test_packed_dataset.py
  • src/axolotl/cli/args.py
  • src/axolotl/common/datasets.py
🚧 Files skipped from review as they are similar to previous changes (6)
  • src/axolotl/utils/trainer.py
  • src/axolotl/utils/schemas/training.py
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/utils/data/utils.py
  • tests/e2e/test_streaming.py
  • tests/test_datasets.py
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/validation.py

539-541: Use a single if statement instead of nested if statements

(SIM102)


556-557: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


1461-1462: Use a single if statement instead of nested if statements

(SIM102)

🔇 Additional comments (19)
src/axolotl/utils/data/shared.py (5)

528-532: LGTM! Clean implementation of mixed dataset type handling.

The updated signature and return type now properly support both Dataset and IterableDataset types, aligning with the PR's streaming dataset improvements.


543-550: Smart optimization for single dataset cases.

Great job on the early return logic that avoids unnecessary shuffling for curriculum sampling, disabled shuffling, or streaming datasets. This is both efficient and correct.


551-556: Good conditional dataset shuffling logic.

The logic correctly identifies when pre-merge shuffling is appropriate (only for non-streaming datasets), avoiding a full shuffle on IterableDataset instances, which support only approximate, buffer-based shuffling.


559-572: Excellent post-merge shuffling with appropriate warnings.

The implementation correctly:

  1. Skips shuffling for IterableDataset instances (which support only buffer-based shuffling, not a full shuffle)
  2. Warns about potential issues when shuffling with curriculum sampling
  3. Provides clear debug logging for all branches

576-610: Well-structured dataset mixing implementation with proper validation.

The implementation correctly handles all mixing strategies including the new "random" strategy that was missing in the previous commit (as noted in past reviews). The equal probability calculation for random mixing (1.0 / len(datasets)) is correct.

The error messages are clear and helpful, properly guiding users when they misconfigure streaming with concatenation.
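
The mixing strategies discussed above can be sketched in pure Python over plain iterables. This is an illustrative stand-in, not the PR's implementation: the function names are hypothetical, and stopping as soon as one source is exhausted is an assumption chosen to keep the sketch short.

```python
import random
from itertools import cycle
from typing import Iterable, Iterator, Sequence


def round_robin(sources: Sequence[Iterable]) -> Iterator:
    """Yield one item from each source in turn; stop when any source runs out."""
    iterators = [iter(s) for s in sources]
    for it in cycle(iterators):
        try:
            yield next(it)
        except StopIteration:
            return


def weighted(
    sources: Sequence[Iterable], probabilities: Sequence[float], seed: int = 0
) -> Iterator:
    """Pick the next source at random according to probabilities."""
    rng = random.Random(seed)
    iterators = [iter(s) for s in sources]
    indices = range(len(iterators))
    while True:
        i = rng.choices(indices, weights=probabilities, k=1)[0]
        try:
            yield next(iterators[i])
        except StopIteration:
            return


def random_mix(sources: Sequence[Iterable], seed: int = 0) -> Iterator:
    """'random' is the weighted strategy with equal probability 1/len(sources)."""
    n = len(sources)
    return weighted(sources, [1.0 / n] * n, seed=seed)
```

For instance, `list(round_robin([[1, 2, 3], ["a", "b", "c"]]))` alternates between the two sources, while `random_mix` draws from each with probability 0.5.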

src/axolotl/datasets.py (1)

48-89: Clean adaptation for streaming dataset support.

The process method now correctly handles both regular and iterable datasets:

  • Skips feature extraction and column removal for IterableDataset (since features aren't available upfront)
  • Properly configures num_proc only for regular datasets (streaming doesn't support multiprocessing)

The implementation is well-structured and maintains backward compatibility.

src/axolotl/utils/schemas/validation.py (4)

1415-1432: Well-designed streaming detection logic.

The _is_streaming_enabled helper correctly handles the complex logic for determining streaming state:

  1. Respects explicit eval_streaming for evaluation context
  2. Falls back to main streaming setting
  3. Defaults to streaming for pretraining datasets when not explicitly configured

This implementation aligns perfectly with the streaming configuration hierarchy described in the PR.
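
The three-step hierarchy can be sketched as a standalone function over a plain dict. The function name and the use of `pretraining_dataset` as the fallback signal are assumptions made for illustration; the actual `_is_streaming_enabled` helper may differ in its details.

```python
def is_streaming_enabled(cfg: dict, split: str = "train") -> bool:
    """Resolve whether a split should stream, following the hierarchy above."""
    # 1. An explicit eval_streaming setting wins for evaluation splits.
    if split != "train" and cfg.get("eval_streaming") is not None:
        return bool(cfg["eval_streaming"])
    # 2. Otherwise fall back to the main streaming flag when it is set.
    if cfg.get("streaming") is not None:
        return bool(cfg["streaming"])
    # 3. With nothing configured, stream only when a pretraining dataset is present.
    return bool(cfg.get("pretraining_dataset"))
```

With `{"streaming": False, "eval_streaming": True}`, this sketch streams the eval split but not the train split.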


1433-1443: Essential validation for streaming datasets.

Correctly enforces the requirement that max_steps must be set for streaming datasets since the Trainer cannot infer dataset length.


1444-1455: Good guard against incompatible configuration.

Properly prevents validation splits with streaming datasets, which is a technical limitation. The error message clearly explains why this combination isn't supported.


1130-1153: Good consolidation of accelerator configuration.

The new check_streaming_split_batches_accelerate method properly extends the existing pretraining logic to also handle streaming datasets. The implementation correctly configures accelerator settings for both training and evaluation streaming scenarios.

src/axolotl/utils/data/sft.py (9)

47-67: Well-designed split-specific streaming configuration.

The _is_streaming_enabled_for_split function correctly implements the streaming hierarchy:

  1. Respects explicit eval_streaming for test splits
  2. Falls back to main streaming setting
  3. Defaults to streaming for pretraining datasets

This aligns perfectly with the validation logic and provides a clean abstraction.


69-87: Clean configuration override for evaluation splits.

The _get_streaming_config_for_split function properly creates split-specific configurations, allowing evaluation datasets to use different mixing strategies and weights than training datasets.


105-108: Clean routing to specialized dataset preparation.

The simplified dispatch logic correctly routes to the appropriate preparation function based on dataset type.


148-161: Good defensive programming for streaming eval datasets.

The code correctly skips sample packing validation for streaming eval datasets since they don't have a calculable length. The comment clearly explains why this check is necessary.


163-172: Proper step calculation for different dataset types.

The implementation correctly handles step counting:

  • For IterableDataset: Uses cfg.max_steps directly (required for streaming)
  • For regular datasets: Calculates based on dataset size with optional max_steps cap

This aligns with the validation requirement that max_steps must be set for streaming datasets.
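
A simplified sketch of the two branches described above. `total_training_steps` is a hypothetical helper; the real `calculate_total_num_steps` presumably also accounts for sample packing, gradient accumulation and world size, none of which is modelled here.

```python
import math
from typing import Optional


def total_training_steps(
    num_examples: Optional[int],
    batch_size: int,
    num_epochs: int,
    max_steps: Optional[int] = None,
) -> int:
    """Return the step count; num_examples is None for a streaming dataset."""
    if num_examples is None:
        # Streaming: the length is unknown, so max_steps is the only source of truth.
        if max_steps is None:
            raise ValueError("max_steps must be set for streaming datasets")
        return max_steps
    # Map-style: derive the count from the dataset size, capped by max_steps if given.
    steps = math.ceil(num_examples / batch_size) * num_epochs
    return min(steps, max_steps) if max_steps is not None else steps
```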


318-356: Well-structured caching logic with streaming support.

The implementation correctly:

  1. Skips all caching (hub/disk) for streaming datasets
  2. Maintains backward compatibility with non-streaming caching flow
  3. Uses split-specific streaming configuration

This provides a clean separation between streaming and non-streaming paths.


401-407: Correct handling of streaming datasets.

The code properly avoids saving streaming datasets to disk (which isn't supported) while maintaining the save functionality for regular datasets.


471-497: Comprehensive handling of dataset splits with streaming support.

Both _handle_train_dataset_split and _handle_test_dataset_split functions correctly:

  1. Accept both Dataset and IterableDataset types
  2. Skip deduplication for streaming datasets with appropriate logging
  3. Maintain correct return type signatures

The implementation is thorough and handles all edge cases properly.


499-511: See comment for lines 471-497.

Comment thread src/axolotl/datasets.py
@codecov

codecov Bot commented Aug 21, 2025


Comment thread src/axolotl/cli/args.py
@salmanmohammadi

Contributor

What did you find in your manual testing?

Comment thread src/axolotl/datasets.py
@djsaunde

djsaunde commented Aug 21, 2025

Collaborator Author

What did you find in your manual testing?

@salmanmohammadi Basic tests worked with no noticeable extra data time. I should get benchmark numbers (esp for more complicated situations; e.g., multidataset, multiprocess) before merge, though.

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

♻️ Duplicate comments (2)
src/axolotl/utils/data/shared.py (1)

576-610: Unify default strategy with schema; add a defensive check for weighted

  • The function falls back to "concatenate" if the key is missing, but the schema default is "round_robin". Use the same default here to avoid surprises in call sites that pass a plain dict lacking the key.
  • Add a minimal runtime check for weighted probabilities (length and sum) to fail fast even if validation is bypassed for some code paths (e.g., merging test_datasets with global mixing settings but no prior validation).

Apply:

-    strategy = cfg.get("dataset_mixing_strategy", "concatenate")
+    strategy = cfg.get("dataset_mixing_strategy", "round_robin")
     weights = cfg.get("mixing_weights", None)

@@
-    if strategy == "weighted":
-        return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
+    if strategy == "weighted":
+        if weights is None:
+            raise ValueError("mixing_weights must be provided when strategy='weighted'")
+        if len(weights) != len(datasets):
+            raise ValueError(
+                f"mixing_weights length ({len(weights)}) must match number of datasets ({len(datasets)})"
+            )
+        if not (abs(sum(weights) - 1.0) <= 1e-6):
+            raise ValueError(f"mixing_weights must sum to 1.0, got {sum(weights)}")
+        return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)

Also, nice job adding the "random" branch—this addresses prior feedback about config/implementation parity.

src/axolotl/utils/data/sft.py (1)

288-301: Critical: per-split streaming decision ignored for ‘test’; eval streaming cannot be enabled

Both _load_tokenized_prepared_datasets and _load_and_process_single_dataset only compute use_streaming for the “train” split, so evaluation never streams. This regresses scenarios like streaming=False with eval_streaming=True and matches the error reported in the PR discussion.

Apply this diff to respect per-split streaming in both places:

@@
-    use_streaming = False
-    if split == "train":
-        use_streaming = _is_streaming_enabled(cfg)
+    use_streaming = _is_streaming_enabled_for_split(cfg, split)
@@
-    use_streaming = False
-    if split == "train":
-        use_streaming = _is_streaming_enabled(cfg)
+    use_streaming = _is_streaming_enabled_for_split(cfg, split)

I can also add a small unit/e2e test covering {streaming: false, eval_streaming: true} to prevent regressions. Want me to push that?

Also applies to: 391-396

🧹 Nitpick comments (9)
src/axolotl/utils/schemas/config.py (1)

935-952: Type-safety for mixing strategy; align default with merge layer

Great addition. Two improvements:

  • Constrain dataset_mixing_strategy at the type level to prevent invalid values earlier (you already validate later; this complements it).
  • Ensure the default here matches the merge function’s fallback (shared._merge_datasets_with_strategy currently defaults to "concatenate" if key missing).

Apply:

-    dataset_mixing_strategy: str | None = Field(
-        default="round_robin",
+    dataset_mixing_strategy: Literal["concatenate", "round_robin", "weighted", "random"] = Field(
+        default="round_robin",
         json_schema_extra={
             "description": "Strategy for mixing multiple datasets: 'concatenate', 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets."
         },
     )

Optionally, if you want to tighten weights:

-    mixing_weights: list[float] | None = Field(
+    mixing_weights: list[float] | None = Field(
         default=None,

(Weights length and sum are already validated in StreamingValidationMixin, so this is optional.)

src/axolotl/utils/schemas/validation.py (3)

519-531: Auto-default num_epochs to 1 — prefer int over float

Good guard. Minor nit: defaulting to an int avoids downstream surprises where integer semantics are assumed.

Apply:

-        if max_steps is None and num_epochs is None:
-            data["num_epochs"] = 1.0
+        if max_steps is None and num_epochs is None:
+            data["num_epochs"] = 1

1130-1150: Streaming accelerate defaults — DRY with pretraining path

This mirrors check_pretraining_split_batches_accelerate. Consider factoring shared logic into a small helper to avoid divergence over time.


1409-1529: Streaming validations are comprehensive; tweak messaging and centralize constants

  • The “Validation splits not supported …, skipping” message in check_streaming_validation_splits_conflict raises an error; remove “skipping” to avoid confusion.
  • valid_strategies is hard-coded here and in data/shared.py. Consider centralizing the set (e.g., an Enum or module-level constant) to prevent drift.

Apply:

-                raise ValueError(
-                    "Validation splits not supported for streaming datasets, skipping"
-                )
+                raise ValueError(
+                    "Validation splits are not supported for streaming datasets."
+                )

If you’d like, I can draft a small constants module (e.g., src/axolotl/utils/data/mixing_strategies.py) and wire both the validator and merger to import from it.
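
One possible shape for such a constants module, as a sketch only. The `MixingStrategy` enum and `parse_strategy` helper are hypothetical names; only the four strategy strings come from the PR.

```python
from enum import Enum


class MixingStrategy(str, Enum):
    """Single source of truth for the supported dataset mixing strategies."""

    CONCATENATE = "concatenate"
    ROUND_ROBIN = "round_robin"
    WEIGHTED = "weighted"
    RANDOM = "random"


VALID_STRATEGIES = frozenset(s.value for s in MixingStrategy)


def parse_strategy(value: str) -> MixingStrategy:
    """Convert a config string to a MixingStrategy, with a readable error."""
    try:
        return MixingStrategy(value)
    except ValueError:
        raise ValueError(
            f"dataset_mixing_strategy must be one of {sorted(VALID_STRATEGIES)}, "
            f"got {value!r}"
        ) from None
```

Because the enum subclasses `str`, members still compare equal to the raw config strings, so existing `strategy == "weighted"` checks would keep working if both the validator and the merger imported from one place.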

tests/e2e/test_streaming.py (1)

31-67: Make FA usage and external dependencies more CI-stable

Two small tweaks to reduce CI flakiness:

  • flash_attention=True may fail on CPU-only runners; consider using a config knob like flash_attention: False (or leave it unset) for these smoke tests.
  • The tests pull public datasets/models; if rate-limited, runs can fail. If available, consider pinning to tiny local fixtures or HF “dummy” datasets for CI.
src/axolotl/utils/data/sft.py (4)

118-124: Clarify/guard: sample_packing with streaming eval

You skip the “too small” step-count check for IterableDataset evals, but if packing is enabled on a streaming eval dataset, is it actually supported? If not, fail fast with a clear error; if yes, document expected behavior.

Optionally add a guard:

     if (
         eval_dataset
         and cfg.sample_packing
         and cfg.eval_sample_packing is not False
         and not isinstance(eval_dataset, IterableDataset)
     ):
         total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
         if total_eval_steps == 0:
             raise ValueError(
                 "eval dataset split is too small for sample_packing. "
                 "You should set `eval_sample_packing: False` in your config."
             )
+
+    # If eval is streaming, either disable packing or explicitly support it.
+    if (
+        isinstance(eval_dataset, IterableDataset)
+        and cfg.sample_packing
+        and cfg.eval_sample_packing is not False
+    ):
+        raise ValueError(
+            "Evaluation with streaming datasets does not support sample_packing. "
+            "Set `eval_sample_packing: False` or disable `sample_packing`."
+        )

If packing on streaming evals is supported in process_datasets_for_packing, ignore this and add a brief comment here stating so.


273-274: Type hint: include IterableDataset in return type

_load_tokenized_prepared_datasets can return an IterableDataset (streaming path), but the annotation excludes it.

Apply:

-) -> tuple[Dataset | DatasetDict, list[Prompter | None]]:
+) -> tuple[Dataset | IterableDataset | DatasetDict, list[Prompter | None]]:

334-337: Type hint: _load_raw_datasets can return IterableDataset

The merged dataset may be streaming; widen the annotation.

-) -> tuple[Dataset, list[Prompter | None]]:
+) -> tuple[Dataset | IterableDataset, list[Prompter | None]]:

445-452: Good: skip exact deduplication for streaming datasets

Dedup on IterableDataset isn’t supported; the explicit skip and log message are appropriate. Consider consolidating the two log messages into a small helper if this pattern appears elsewhere.

Also applies to: 462-467, 472-481

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between 5c669fb and 8ebe7d8.

📒 Files selected for processing (14)
  • src/axolotl/cli/args.py (0 hunks)
  • src/axolotl/common/datasets.py (0 hunks)
  • src/axolotl/datasets.py (4 hunks)
  • src/axolotl/utils/data/sft.py (9 hunks)
  • src/axolotl/utils/data/shared.py (4 hunks)
  • src/axolotl/utils/data/utils.py (1 hunks)
  • src/axolotl/utils/schemas/config.py (1 hunks)
  • src/axolotl/utils/schemas/training.py (1 hunks)
  • src/axolotl/utils/schemas/validation.py (6 hunks)
  • src/axolotl/utils/trainer.py (2 hunks)
  • tests/e2e/integrations/test_kd.py (1 hunks)
  • tests/e2e/test_streaming.py (1 hunks)
  • tests/test_datasets.py (4 hunks)
  • tests/test_packed_dataset.py (0 hunks)
💤 Files with no reviewable changes (3)
  • src/axolotl/cli/args.py
  • src/axolotl/common/datasets.py
  • tests/test_packed_dataset.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • src/axolotl/utils/trainer.py
  • src/axolotl/utils/schemas/training.py
  • src/axolotl/utils/data/utils.py
  • tests/test_datasets.py
  • src/axolotl/datasets.py
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/validation.py

539-541: Use a single if statement instead of nested if statements

(SIM102)


556-557: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)


1453-1454: Use a single if statement instead of nested if statements

(SIM102)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
🔇 Additional comments (15)
src/axolotl/utils/data/shared.py (3)

528-549: Single-dataset path: shuffle guard is correct

The early-return logic for a single dataset properly avoids shuffling for streaming or curriculum cases; otherwise it shuffles deterministically. LGTM.


551-556: Pre-merge per-dataset shuffle is reasonable

Conditioning the pre-merge shuffle on all-Arrow Datasets avoids materializing streaming sources. LGTM.


557-573: Shuffle-after-merge behavior is sensible with clear logging

Skipping shuffle for IterableDataset and warning on curriculum+shuffle is good UX. LGTM.

src/axolotl/utils/schemas/validation.py (4)

6-6: Import addition is fine

os is required for the streaming preprocess guard. Nothing to change.


196-196: Pylint suppression is acceptable here

Given the number of validations, suppressing too-many-public-methods is reasonable.


532-548: saves_per_epoch requires num_epochs — good validation

This prevents runtime division/interval ambiguities. LGTM.


549-564: evals_per_epoch requires num_epochs — good validation

Consistent with saves_per_epoch handling. LGTM.

tests/e2e/test_streaming.py (4)

18-25: Solid coverage for round_robin, weighted, and random

Nice parameterization to hit all three strategies in one test.


88-126: Validation error test is precise

Good negative test asserting both mismatched length and the explanatory phrase. LGTM.


127-171: Three-dataset weighted case expands surface area nicely

Covers non-binary weighting with correct normalization. LGTM.


81-86: Manual Verification Required: Confirm TensorBoard Scalar Tags

I wasn’t able to locate any TensorBoard event files in ${TMPDIR} to automatically inspect the emitted tags. Please run your streaming e2e test locally, point it at a writable log directory, and confirm which scalar keys are emitted. For example:

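#!/bin/bash
set -euo pipefail
# Assumption: the test writes its TensorBoard logs under the temp directory and
# does not delete them afterwards. pytest itself has no --logdir option, so
# point TMPDIR at a known location instead.
export TMPDIR=/tmp/tb_logs
mkdir -p "$TMPDIR"
pytest tests/e2e/test_streaming.py --capture=no --tb=short
# Pass the event file as an argument so the heredoc can stay on python's stdin.
EVENT_FILE=$(fd -t f -g 'events.out.tfevents*' "$TMPDIR" | head -n1)
python - "$EVENT_FILE" <<'PY'
import sys
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
ea = EventAccumulator(sys.argv[1]); ea.Reload()
# Tags() lists the available tags per category; Scalars() needs a specific tag.
print(sorted(ea.Tags()["scalars"]))
PY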

Then verify:

  • Are both "train/train_loss" and "train/loss" present?
    • If yes, no further action is needed.
    • If only one appears, please align this test’s expected tag (at lines 81–86 in tests/e2e/test_streaming.py) to use "train/loss", matching the other e2e tests, to avoid false negatives.

src/axolotl/utils/data/sft.py (4)

371-377: Good: avoid saving streaming datasets to disk

Skipping persistence for IterableDataset prevents subtle runtime surprises and unnecessary I/O.


400-405: Good: handle both DatasetDict and IterableDatasetDict

Handling both dict types here avoids split selection bugs when sources expose dicts in streaming mode.


12-12: Import: IterableDatasetDict addition looks right

This enables split selection for streaming dataset dicts; aligns with the downstream instance checks.


486-489: Verify shard semantics on streaming vs. map-style datasets

Ensure that calling dataset.shard(...) on an IterableDataset (which returns a contiguous group of the stream's underlying shards) matches your intended splitting strategy, since map-style Dataset.shard(...) applies strided (every-Nth sample) selection when contiguous=False, which was the default in older datasets releases. This distinction can affect data balance and mixing when shards are later combined (e.g., with weighted or round-robin strategies).

• Review _apply_dataset_sharding in src/axolotl/utils/data/sft.py (definition at lines 486–494; shard call at lines 502–504) to confirm contiguous-segment sharding is acceptable for iterable streams.
• Review the initial sharding after dataset.shuffle(seed=…) in the same file (lines 412–416) to ensure strided sampling on map-style datasets yields the desired distribution.
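To make the contiguous versus strided distinction concrete, here is a small illustration on a plain list. It only mimics the two selection patterns; the helper names are hypothetical, and which pattern a given `datasets` version applies by default should be checked against its documentation.

```python
def contiguous_shard(items: list, num_shards: int, index: int) -> list:
    """Return the index-th contiguous chunk; earlier shards absorb the remainder."""
    div, mod = divmod(len(items), num_shards)
    start = div * index + min(index, mod)
    end = start + div + (1 if index < mod else 0)
    return items[start:end]


def strided_shard(items: list, num_shards: int, index: int) -> list:
    """Return every num_shards-th item, starting at position index."""
    return items[index::num_shards]


items = list(range(10))
contiguous = [contiguous_shard(items, 4, i) for i in range(4)]
strided = [strided_shard(items, 4, i) for i in range(4)]
```

If the underlying data is ordered (for example by source or difficulty), contiguous chunks give each rank a different slice of that ordering, whereas strided selection spreads it evenly.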

Comment thread src/axolotl/utils/data/sft.py
Comment thread src/axolotl/utils/data/sft.py
Comment thread tests/e2e/integrations/test_kd.py

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
src/axolotl/datasets.py (1)

92-103: Potential crash for streaming datasets: dataset.features can be None on IterableDataset; also inconsistent with process() behavior.

IterableDataset often has features=None (e.g., JSONL streams). Calling list(dataset.features.keys()) will raise. In process(), you already skip remove_columns for IterableDataset; wrap_dataset_for_tokenized_prompt should be consistent and handle features safely.

Apply this diff to safely compute remove_columns and reuse the tokenizer’s batch_size when batched. No desc is added here because IterableDataset.map does not accept a desc argument:

@@
-    if isinstance(dataset, IterableDataset):
-        map_kwargs = {}
-        if prompt_tokenizer.supports_batched:
-            map_kwargs["batched"] = True
-
-        # Map the dataset and remove original columns
-        return dataset.map(
-            prompt_tokenizer.tokenize_prompt,
-            remove_columns=list(dataset.features.keys()),
-            **map_kwargs,
-        )
+    if isinstance(dataset, IterableDataset):
+        map_kwargs = {}
+        if prompt_tokenizer.supports_batched:
+            map_kwargs["batched"] = True
+            map_kwargs["batch_size"] = getattr(prompt_tokenizer, "batch_size", 1_000)
+
+        # Determine removable columns safely; features may be None for streaming sources
+        features = getattr(dataset, "features", None)
+        remove_columns = list(features.keys()) if features is not None else None
+
+        # Map the dataset; if remove_columns is None, leave originals and rely on downstream collators
+        return dataset.map(
+            prompt_tokenizer.tokenize_prompt,
+            remove_columns=remove_columns,
+            **map_kwargs,
+        )

Follow-up: If leaving originals causes model forward errors (as previously noted in review), you may need a downstream batch filter that whitelists only model-accepted keys when features=None. I can provide a small utility for that if helpful.
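A minimal sketch of such a utility follows. The name `keep_model_keys` and the default key list are assumptions for illustration; the real set of accepted keys depends on the model's forward signature.

```python
from typing import Any, Iterable

# Assumed whitelist; adjust to whatever the model's forward() actually accepts.
MODEL_INPUT_KEYS = ("input_ids", "attention_mask", "labels", "position_ids")


def keep_model_keys(
    example: dict[str, Any], allowed: Iterable[str] = MODEL_INPUT_KEYS
) -> dict[str, Any]:
    """Return a copy of the example containing only whitelisted keys."""
    allowed_set = set(allowed)
    return {key: value for key, value in example.items() if key in allowed_set}
```

Because `map` merges the returned dict into the existing example rather than replacing it, a filter like this would more plausibly be applied where batches are assembled (for example inside a collator wrapper) than inside `dataset.map`.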

src/axolotl/utils/trainer.py (2)

371-375: Eval IterableDataset still uses .filter(...) — align with train path to avoid streaming breakage.

This likely explains the reported eval streaming error. Mirror the train path and use the custom helper with batched=True.

Apply:

-        eval_dataset = eval_dataset.filter(
-            drop_no_trainable_tokens,
-            **filter_map_kwargs,
-            **drop_long_kwargs,
-        )
+        if isinstance(eval_dataset, IterableDataset):
+            eval_dataset = _create_filtered_iterable_dataset(
+                eval_dataset, drop_no_trainable_tokens, batched=True
+            )
+        else:
+            eval_dataset = eval_dataset.filter(
+                drop_no_trainable_tokens,
+                batched=True,
+                **filter_map_kwargs,
+                **drop_long_kwargs,
+            )

Note: filter_map_kwargs was computed from the train dataset type; if train is map-style but eval is iterable (or vice versa), consider computing kwargs per-dataset to avoid passing num_proc to iterable filters.
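One way to compute the kwargs per dataset is sketched below. The helper name `build_filter_kwargs` is hypothetical; it relies only on the fact that `IterableDataset.filter` takes no `num_proc` or `desc` parameters.

```python
from typing import Any, Optional


def build_filter_kwargs(
    is_iterable: bool, num_proc: Optional[int] = None, desc: Optional[str] = None
) -> dict[str, Any]:
    """Build filter kwargs for one dataset based on its own type."""
    kwargs: dict[str, Any] = {"batched": True}
    if not is_iterable:
        # Multiprocessing and progress descriptions only apply to map-style datasets.
        if num_proc is not None:
            kwargs["num_proc"] = num_proc
        if desc is not None:
            kwargs["desc"] = desc
    return kwargs
```

Calling this once for the train dataset and once for the eval dataset avoids leaking map-style options into a streaming filter call when the two differ in type.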


319-329: Critical bug in label filtering: list-vs-int comparison prevents dropping zero-token samples

The use of np.any(labels != -100) and np.any(row_labels != -100) on plain Python lists always returns a single boolean True (because labels != -100 is a scalar list-vs-int comparison), so rows consisting entirely of -100 are never filtered out. We confirmed these are the only occurrences in the repo (both at src/axolotl/utils/trainer.py:323 and …:327).

Please apply the following patch:

--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -320,7 +320,8 @@ def should_keep_sample(labels):
         # If first element is an int, we assume a single example
         # If it's a list, we assume we're dealing with a batch
-        if isinstance(labels[0], int):
-            # Single example: return a single bool
-            return np.any(labels != -100)
+        if isinstance(labels[0], int):
+            # Single example: return a single bool (drop if all -100)
+            return any(label != -100 for label in labels)
 
         # Batched: 'labels' is a list of lists
         # Return a list of booleans, one per sub-list
-        results = [np.any(row_labels != -100) for row_labels in labels]
-        return results
+        results = [
+            any(label != -100 for label in row_labels)  # per-example drop logic
+            for row_labels in labels
+        ]
+        return results

– Verified only these two calls of np.any(… != -100) exist in the codebase.
– Switching to Python’s built-in any() ensures correct element-wise truth evaluation on lists.
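The pitfall is easy to reproduce without any library: comparing a list with an int yields one bool, not an element-wise mask. The sketch below mirrors the patched logic; `should_keep_sample` is taken from the diff above, and the sample values are made up.

```python
def should_keep_sample(labels):
    """Return a bool for one example, or a list of bools for a batch."""
    if isinstance(labels[0], int):
        # Single example: keep it only if at least one label is trainable.
        return any(label != -100 for label in labels)
    # Batched: one decision per row.
    return [any(label != -100 for label in row) for row in labels]


all_masked = [-100, -100, -100]

# A list compared with an int is a single object comparison, so this is just True,
# and wrapping it in any()-style reduction cannot recover per-element results.
naive_mask = all_masked != -100

single = should_keep_sample(all_masked)
batch = should_keep_sample([[-100, -100], [-100, 7]])
```

Converting to an array first (`np.asarray(labels) != -100`) would also give an element-wise comparison, if staying with NumPy is preferred.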

🧹 Nitpick comments (5)
src/axolotl/datasets.py (4)

4-6: Fix minor grammar in the docstring.

Use the contraction “Let’s”.

Apply this diff:

- We want this to be a wrapper for an existing dataset that we have loaded. Lets use the
+ We want this to be a wrapper for an existing dataset that we have loaded. Let's use the

46-46: Narrow the return type or document the expectation.

process() advertises Dataset | IterableDataset, but __init__ immediately accesses .data, which exists only for Dataset. Either narrow the return type to Dataset for this path or document that TokenizedPromptDataset must not be constructed with an IterableDataset.


48-51: Guard against Features.keys() returning a view; pass a concrete list to remove_columns.

Hugging Face APIs expect str or list[str] for remove_columns. Using a keys view can be brittle.

Apply this diff:

-        features = None
-        if not isinstance(dataset, IterableDataset):
-            features = dataset.features.keys()
+        features = None
+        if not isinstance(dataset, IterableDataset):
+            features = list(dataset.features.keys())

Also applies to: 75-80


53-56: Make batch size configurable (avoid hard-coded 1_000).

Different tokenizers/sequences need different batching to avoid OOM or underutilization. Read an optional batch_size from the tokenizer, with a sane default.

Apply this diff:

-        if self.prompt_tokenizer.supports_batched:
-            map_kwargs["batched"] = True
-            map_kwargs["batch_size"] = 1_000
+        if self.prompt_tokenizer.supports_batched:
+            map_kwargs["batched"] = True
+            map_kwargs["batch_size"] = getattr(self.prompt_tokenizer, "batch_size", 1_000)

src/axolotl/utils/trainer.py (1)

621-633: Avoid initializing NCCL for single-process CPU/MPS runs.

Gate the deepspeed comm init on CUDA availability to prevent crashes on non-CUDA setups.

Apply:

-    if (
-        int(os.environ.get("WORLD_SIZE", "1")) == 1
-        and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
-    ):
+    if (
+        int(os.environ.get("WORLD_SIZE", "1")) == 1
+        and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
+        and torch.cuda.is_available()
+    ):
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro


📥 Commits

Reviewing files that changed from the base of the PR and between 8ebe7d8 and 37f662a.

📒 Files selected for processing (2)
  • src/axolotl/datasets.py (4 hunks)
  • src/axolotl/utils/trainer.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/utils/trainer.py (1)
src/axolotl/core/builders/base.py (2)
  • train_dataset (81-82)
  • train_dataset (85-86)
src/axolotl/datasets.py (2)
src/axolotl/prompt_strategies/chat_template.py (1)
  • tokenize_prompt (391-416)
src/axolotl/prompt_tokenizers.py (3)
  • tokenize_prompt (66-67)
  • tokenize_prompt (118-144)
  • tokenize_prompt (259-288)
🪛 Ruff (0.12.2)
src/axolotl/utils/trainer.py

43-43: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


59-59: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: pre-commit
  • GitHub Check: pre-commit
  • GitHub Check: preview
🔇 Additional comments (3)
src/axolotl/datasets.py (1)

1-7: Docstring cleanup looks good.

Thanks for removing references to ConstantLengthDataset and clarifying the middleware-style approach and padding via collators.

src/axolotl/utils/trainer.py (2)

344-357: LGTM: Custom filter path for train IterableDatasets avoids features/caching pitfalls.

This is the right direction for streaming. With the mask-fix above, it should behave correctly for batched filtering.


542-545: Confirm no collation issues when including the length column in the temporary DataLoader.

Since this DataLoader is only used to compute len(data_loader), it won’t iterate examples; still, confirm no side effects (e.g., accidental reuse elsewhere where a collator might choke on length). If issues surface, swap back to a .remove_columns(["length"]) view specifically for this estimator path.

Comment thread src/axolotl/utils/trainer.py
@djsaunde djsaunde changed the title from "Better support for streaming datasets" to "Better support for streaming datasets; multidataset weighting / round robin" Aug 22, 2025

@coderabbitai coderabbitai Bot left a comment

Contributor

Actionable comments posted: 0

♻️ Duplicate comments (1)
src/axolotl/datasets.py (1)

97-106: Risk: not removing source columns for streaming can break model.forward.

When features is None (common for streaming), remove_columns stays None; original columns persist and often get forwarded to model(**batch), causing unexpected-arg errors. This matches a prior maintainer note about needing to remove extra columns to avoid forward errors.

Two robust options:

  • Option A (fail fast, preferred): require features for streaming datasets so remove_columns is known and safe.
  • Option B (soft fallback): warn loudly when features is None and encourage providing features at load_dataset(..., streaming=True, features=...); still proceed but make the risk explicit.

Minimal changes for Option A:

-        remove_columns = None
-        if dataset.features is not None:
-            remove_columns = list(dataset.features.keys())
+        remove_columns: list[str] | None = None
+        if dataset.features is not None:
+            remove_columns = list(dataset.features.keys())
+        else:
+            # Failing fast avoids silent training/runtime errors later.
+            raise ValueError(
+                "IterableDataset.features is None; cannot determine remove_columns. "
+                "Please create the streaming dataset with explicit `features=...` so "
+                "we can drop source columns after tokenization."
+            )

If you’re not ready to hard-fail, consider this soft fallback:

-        remove_columns = None
-        if dataset.features is not None:
-            remove_columns = list(dataset.features.keys())
+        remove_columns = (
+            list(dataset.features.keys()) if dataset.features is not None else None
+        )
+        if remove_columns is None:
+            LOG.warning(
+                "IterableDataset.features is None; proceeding without remove_columns. "
+                "Upstream code MUST ensure only model-expected keys are forwarded, or "
+                "provide `features` when building the streaming dataset."
+            )

Recommendation: also update dataset loading sites to pass explicit Features for streaming datasets so remove_columns can be computed deterministically.

🧹 Nitpick comments (8)
src/axolotl/datasets.py (8)

1-7: Docstring nit: fix contraction and tighten wording.

Minor grammar and clarity improvements. Also keeps terminology consistent with “middleware-style wrappers”.

- We want this to be a wrapper for an existing dataset that we have loaded. Lets use the
+ We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
  concept of middlewares to wrap each dataset. We'll use the collators later on to pad the
  datasets.

20-28: Class docstring is misleading for non-iterable path.

This class is only constructed for non-Iterable datasets (Iterable path returns a mapped dataset directly). The docstring saying “stream of text files” can confuse readers.

-"""Dataset that returns tokenized prompts from a stream of text files.
+"""Arrow-backed dataset wrapper that returns tokenized prompts.
+
+Note: Streaming (IterableDataset) is handled in wrap_dataset_for_tokenized_prompt
+and does not instantiate this class.

46-51: Prefer column_names and ensure a concrete list for remove_columns.

Using dataset.column_names is the canonical, library-supported way and returning a list avoids surprises when passed into remove_columns.

-        features = None
-        if not isinstance(dataset, IterableDataset):
-            features = dataset.features.keys()
+        features: list[str] | None = None
+        if not isinstance(dataset, IterableDataset):
+            # Stable, explicit list of original columns for remove_columns
+            features = list(dataset.column_names)

61-68: Guard filter_rows with callable check and avoid truthiness pitfalls.

If filter_rows exists but isn’t callable (e.g., set to True/False), dataset.filter will error. Use callable() and bind the name once.

-        if (
-            hasattr(self.prompt_tokenizer, "filter_rows")
-            and self.prompt_tokenizer.filter_rows
-        ):
-            filter_kwargs: dict[str, Any] = {"desc": "Strategy Filtering Rows"}
+        filter_fn = getattr(self.prompt_tokenizer, "filter_rows", None)
+        if callable(filter_fn):
+            filter_kwargs: dict[str, Any] = {"desc": "Strategy Filtering Rows"}
             if not isinstance(dataset, IterableDataset):
                 filter_kwargs["num_proc"] = self.process_count
 
             dataset = dataset.filter(
-                self.prompt_tokenizer.filter_rows,
+                filter_fn,
                 **filter_kwargs,
             )

70-80: Consolidate map kwargs; ensure immutability and future extensibility.

Merging dicts inline is fine, but this block is duplicated with the streaming path below. Consider extracting a small helper that builds map kwargs consistently for both paths.

Follow-up: If you prefer an inline tweak now, you can leave as-is and address deduplication when touching wrap_dataset_for_tokenized_prompt.


92-99: Match batching semantics with non-streaming path.

For IterableDataset you set batched=True but omit batch_size, which can affect throughput. Note that IterableDataset.map does not accept a desc argument, so the progress description belongs on the non-streaming path only.

     if isinstance(dataset, IterableDataset):
-        map_kwargs = {}
-        if prompt_tokenizer.supports_batched:
-            map_kwargs["batched"] = True
+        map_kwargs: dict[str, Any] = {}
+        if prompt_tokenizer.supports_batched:
+            map_kwargs.update({"batched": True, "batch_size": 1_000})

92-109: Reduce duplication: share map option building between paths.

Both branches now assemble highly similar map kwargs. Extract a small helper to keep behavior in lockstep for IterableDataset and Dataset (batching, etc.).

Example shape (outside this function):

def _build_map_kwargs(prompt_tokenizer, *, process_count=None, keep_in_memory=None, is_iterable=False):
    mk = {}
    if prompt_tokenizer.supports_batched:
        mk.update({"batched": True, "batch_size": 1_000})
    if not is_iterable:
        # num_proc, keep_in_memory, and desc are only accepted by Dataset.map
        mk.update({
            "num_proc": process_count,
            "keep_in_memory": keep_in_memory,
            "desc": "Tokenizing Prompts",
        })
    return mk

Then use _build_map_kwargs(...) in both process(...) and wrap_dataset_for_tokenized_prompt(...).


9-17: Unused logger import; consider emitting actionable warnings.

LOG is currently unused. If you adopt the soft fallback for streaming remove_columns, emit a LOG.warning there; otherwise drop the import.

-from axolotl.utils.logging import get_logger
+from axolotl.utils.logging import get_logger
...
-LOG = get_logger(__name__)
+LOG = get_logger(__name__)  # keep only if warning/error is logged for IterableDataset
📜 Review details


📥 Commits

Reviewing files that changed from the base of the PR and between 37f662a and 603c79b.

📒 Files selected for processing (1)
  • src/axolotl/datasets.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/axolotl/datasets.py (4)
src/axolotl/prompt_strategies/chat_template.py (1)
  • tokenize_prompt (391-416)
src/axolotl/prompt_tokenizers.py (3)
  • tokenize_prompt (66-67)
  • tokenize_prompt (118-144)
  • tokenize_prompt (259-288)
src/axolotl/prompt_strategies/completion.py (1)
  • tokenize_prompt (42-60)
src/axolotl/prompt_strategies/pretrain.py (1)
  • tokenize_prompt (48-49)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: pre-commit
  • GitHub Check: PyTest (3.11, 2.7.0)
  • GitHub Check: preview
🔇 Additional comments (1)
src/axolotl/datasets.py (1)

46-84: Add tests to safeguard streaming vs non-streaming column removal

I’ve confirmed that:

  • load_dataset(..., streaming=True) is only used in src/axolotl/utils/data/sft.py, without an explicit features= argument, so all streaming datasets will hit the IterableDataset branch in process() (no remove_columns) [see scan results].
  • There are no code paths forwarding entire batches via model(**batch), so unexpected columns would bubble through unless filtered out.
  • The existing remove_columns logic in src/axolotl/datasets.py (lines 46–84) only applies to non-streaming (Dataset) inputs, while the streaming branch currently leaves original columns intact.

To lock in the intended behavior and prevent regressions, please add tests covering:

  • IterableDataset with known schema
    – Construct an IterableDataset(features=...) with an extra dummy column.
    – After process(), assert that only the tokenizer’s output keys (e.g. input_ids, attention_mask) remain.

  • IterableDataset without initial features
    – Use a raw streaming dataset (no .features).
    – Verify that either:

    1. it raises a clear error when first encountering unexpected columns, or
    2. it emits a warning and still only returns the tokenizer’s output keys.
  • Non-streaming (Dataset) path
    – Map over a regular Dataset with multiple columns.
    – Confirm remove_columns drops the originals.
    – Confirm batching is enabled with batch_size=1000 when supports_batched=True.

These tests will ensure the code behaves correctly for both streaming and non-streaming datasets and guard against future changes.

return merged_dataset


def _merge_datasets_with_strategy(

Collaborator

🔥

yield sample

# Create new IterableDataset from the filtered generator
filtered_dataset = IterableDataset.from_generator(filtered_generator)

Collaborator

Suggested change
- filtered_dataset = IterableDataset.from_generator(filtered_generator)
+ filtered_dataset = IterableDataset.from_generator(filtered_generator, gen_kwargs={"batch_size": 1000})

batch_size as a kwarg feels better than a constant
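For context, `gen_kwargs` entries are forwarded to the generator as keyword arguments, so the generator itself has to accept `batch_size`. A rough sketch of a batched filtering generator with that shape, using hypothetical parameter names `source` and `keep_batch` and plain iterables instead of a real dataset:

```python
from typing import Any, Callable, Iterable, Iterator


def _kept(batch: list, keep_batch: Callable[[list], list]) -> list:
    """Apply the batched predicate and return only the samples it keeps."""
    mask = keep_batch(batch)
    return [sample for sample, keep in zip(batch, mask) if keep]


def filtered_generator(
    source: Iterable[Any],
    keep_batch: Callable[[list], list],
    batch_size: int = 1000,
) -> Iterator[Any]:
    """Buffer samples into batches, filter each batch, and yield the survivors."""
    batch: list = []
    for sample in source:
        batch.append(sample)
        if len(batch) >= batch_size:
            yield from _kept(batch, keep_batch)
            batch = []
    if batch:
        # Flush the final partial batch.
        yield from _kept(batch, keep_batch)
```

With a signature like this, the batch size becomes a caller-supplied parameter rather than a constant buried in the generator body.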

@winglian winglian left a comment

Collaborator

Just a couple of nits. lgtm. Were you able to do any loss curve and step count comparisons to standard SFT w/o streaming?

@djsaunde

Collaborator Author

Just a couple of nits. lgtm. Were you able to do any loss curve and step count comparisons to standard SFT w/o streaming?

Yes, everything looked okay except for sample packing + streaming, will need to fix.

@winglian

Collaborator

Yes, everything looked okay except for sample packing + streaming, will need to fix.

packing and streaming is something that I struggled with due to IterableDatasets and BatchSampler within accelerate iirc.
[Screenshot attached: 2025-08-23 at 12:50:27 PM]

@djsaunde

Collaborator Author

Closing in favor of #3101.

@djsaunde djsaunde closed this Aug 26, 2025
@djsaunde djsaunde mentioned this pull request Aug 26, 2025
6 tasks

3 participants