Conversation
📝 Walkthrough
Removes the preprocess_iterable CLI/config flag and shifts to streaming-first dataset support across loading, preparation, merging, tokenization, validation, and trainer paths; adds dataset mixing strategies (concatenate, round_robin, weighted, random); updates tests and e2e flows for streaming.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Currently seeing an error when using eval streaming datasets; will need to debug.
📖 Documentation Preview: https://68a89608cd414f80a189f3f0--resonant-treacle-0fd729.netlify.app
Deployed on Netlify from commit a7edc77
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/axolotl/utils/data/shared.py (1)
576-612: Bug: cfg.get usage will break with Pydantic config; align default with config and validate weights.
- In this module, cfg is typically a config object with attribute access; using cfg.get will raise if cfg is not Dict-like. Use attribute access or getattr.
- Default strategy here is "concatenate", but config default is "round_robin". Align to avoid surprise behavior when the field is missing.
- For "weighted", validate length and non-negativity and normalize if weights don't sum to 1.0.
    - strategy = cfg.get("dataset_mixing_strategy", "concatenate")
    - weights = cfg.get("mixing_weights", None)
    + strategy = getattr(cfg, "dataset_mixing_strategy", None) or "round_robin"
    + weights = getattr(cfg, "mixing_weights", None)
    @@
    - if strategy == "weighted":
    -     return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
    + if strategy == "weighted":
    +     if not weights or len(weights) != len(datasets):
    +         raise ValueError(
    +             "mixing_weights must be provided and match the number of datasets when "
    +             "dataset_mixing_strategy='weighted'"
    +         )
    +     if min(weights) < 0:
    +         raise ValueError("mixing_weights must be non-negative")
    +     total = sum(weights)
    +     # Normalize if needed
    +     norm_weights = [w / total for w in weights] if total > 0 else weights
    +     return interleave_datasets(datasets, probabilities=norm_weights, seed=cfg.seed)
Add this import at the top of the file if you choose to use math.isclose for exactness instead of simple normalization:
    import math
Alternatively, keep simple normalization as shown above to avoid extra imports.
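The weight checks described in this comment can be sketched as a standalone helper. This is a minimal sketch, not code from the PR: the function name `normalize_mixing_weights` is hypothetical, and unlike the suggestion above it rejects an all-zero weight list instead of passing it through unchanged.

```python
from typing import Optional, Sequence


def normalize_mixing_weights(
    weights: Optional[Sequence[float]], num_datasets: int
) -> list[float]:
    """Validate mixing weights and rescale them to sum to 1.0 (hypothetical helper)."""
    if not weights or len(weights) != num_datasets:
        raise ValueError(
            "mixing_weights must be provided and match the number of datasets"
        )
    if min(weights) < 0:
        raise ValueError("mixing_weights must be non-negative")
    total = sum(weights)
    if total <= 0:
        # Assumption: all-zero weights are treated as a configuration error.
        raise ValueError("mixing_weights must contain at least one positive value")
    return [w / total for w in weights]
```

The returned list could then be passed as the probabilities argument of the interleaving call shown in the diff.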
🧹 Nitpick comments (12)
src/axolotl/utils/schemas/training.py (1)
164-169: Good JSON-schema enrichment for num_epochs. Consider clarifying streaming interplay.
Since this PR introduces streaming, it helps to state that max_steps typically governs iteration when streaming is enabled.
    - num_epochs: float = Field(
    -     default=1.0,
    -     json_schema_extra={
    -         "description": "Number of iterations over dataset for training"
    -     },
    - )
    + num_epochs: float = Field(
    +     default=1.0,
    +     json_schema_extra={
    +         "description": "Number of passes over the dataset for training. When using streaming datasets, training is typically governed by `max_steps` (recommended to set) and `num_epochs` may not correspond to a full pass over data."
    +     },
    + )
src/axolotl/datasets.py (3)
49-85: Avoid large batched mapping on streaming to prevent memory spikes; also prefer column_names for remove_columns.
- Setting batch_size for IterableDataset can accumulate large in-memory batches; better to only set batch_size for non-streaming datasets.
- Using dataset.column_names (or list(features)) is a safer/clearer input for remove_columns than a keys view.
    - features = None
    - if not isinstance(dataset, IterableDataset):
    -     features = dataset.features.keys()
    -
    - map_kwargs: dict[str, Any] = {}
    - if self.prompt_tokenizer.supports_batched:
    -     map_kwargs["batched"] = True
    -     map_kwargs["batch_size"] = 1_000
    + features = None
    + map_kwargs: dict[str, Any] = {}
    + if not isinstance(dataset, IterableDataset):
    +     # Prefer column_names; ensure it's a concrete list
    +     features = list(dataset.column_names)
    +     if self.prompt_tokenizer.supports_batched:
    +         map_kwargs["batched"] = True
    +         map_kwargs["batch_size"] = 1_000
    + else:
    +     if self.prompt_tokenizer.supports_batched:
    +         # For streaming, enable batched but omit batch_size to avoid large in-memory batches
    +         map_kwargs["batched"] = True
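The branching in this suggestion can be isolated into a small pure function, which makes it easy to unit-test without building a dataset. This is only a sketch: `build_map_kwargs` and its parameters are hypothetical names, not part of the PR. Whether omitting batch_size actually lowers memory use depends on the default batch size the dataset library applies to batched maps, so that is worth checking before relying on it.

```python
from typing import Any


def build_map_kwargs(
    is_streaming: bool, supports_batched: bool, batch_size: int = 1_000
) -> dict[str, Any]:
    """Build keyword arguments for a dataset map call (hypothetical helper)."""
    kwargs: dict[str, Any] = {}
    if supports_batched:
        kwargs["batched"] = True
        if not is_streaming:
            # Only pin an explicit batch size for regular, in-memory datasets.
            kwargs["batch_size"] = batch_size
    return kwargs
```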
80-85: remove_columns wiring is correct; minor type-safety nit.
Since features can be a keys view in some contexts, ensure it's a list (covered by the previous suggestion). Otherwise, this looks good.
108-110: If ConstantLengthDataset is test-only, move it under tests to reduce public surface.
If nothing in src imports this except tests, consider relocating it to tests and importing locally in the test to avoid shipping unused runtime code.
Would you like me to draft the test-side helper and remove this class from the package code?
src/axolotl/utils/schemas/config.py (1)
935-970: New streaming and mixing config fields are well-scoped; consider stricter typing and validation hints.
- dataset_mixing_strategy and eval_dataset_mixing_strategy could benefit from Literal types for schema/validation.
- Ensure validations enforce:
- mixing_weights length equals number of datasets, non-negative, and sum to 1.0 (or normalized consistently).
- When streaming=True (or eval_streaming=True), max_steps is provided (as documented).
    - dataset_mixing_strategy: str | None = Field(
    + dataset_mixing_strategy: Literal["concatenate", "round_robin", "weighted", "random"] | None = Field(
          default="round_robin",
          json_schema_extra={
              "description": "Strategy for mixing multiple datasets: 'concatenate', 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets."
          },
      )
    @@
    - eval_dataset_mixing_strategy: str | None = Field(
    + eval_dataset_mixing_strategy: Literal["concatenate", "round_robin", "weighted", "random"] | None = Field(
          default=None,
          json_schema_extra={
              "description": "Strategy for mixing multiple evaluation datasets. If not set, falls back to dataset_mixing_strategy. Options: 'concatenate', 'round_robin', 'weighted', 'random'."
          },
      )
If validations are already added in utils/schemas/validation.py, ignore this and consider marking them in the docstrings for discoverability.
src/axolotl/utils/schemas/validation.py (6)
519-531: Defaulting num_epochs to 1 is good; prefer int over float
Setting a default is helpful. Consider using an int to avoid subtle type inconsistencies downstream.
    - if max_steps is None and num_epochs is None:
    -     data["num_epochs"] = 1.0
    + if max_steps is None and num_epochs is None:
    +     data["num_epochs"] = 1
532-548: Combine nested conditions (ruff SIM102)
Simplify the nested if to a single condition for clarity.
    - if saves_per_epoch is not None:
    -     # Check if saves_per_epoch is set but num_epochs is unset
    -     if num_epochs is None:
    -         raise ValueError(
    -             "saves_per_epoch requires num_epochs to be set to calculate save "
    -             "intervals."
    -         )
    + if saves_per_epoch is not None and num_epochs is None:
    +     raise ValueError(
    +         "saves_per_epoch requires num_epochs to be set to calculate save intervals."
    +     )
549-564: Combine nested conditions (ruff SIM102)
Mirror the change above for evals_per_epoch.
    - if evals_per_epoch is not None:
    -     if num_epochs is None:
    -         raise ValueError(
    -             "evals_per_epoch requires num_epochs to be set to calculate "
    -             "evaluation intervals."
    -         )
    + if evals_per_epoch is not None and num_epochs is None:
    +     raise ValueError(
    +         "evals_per_epoch requires num_epochs to be set to calculate evaluation intervals."
    +     )
1420-1431: Error message says “skipping” but you raise; reword to avoid confusion
You raise a ValueError, so “skipping” is misleading.
    - raise ValueError(
    -     "Validation splits not supported for streaming datasets, skipping"
    - )
    + raise ValueError(
    +     "Validation splits (val_set_size > 0) are not supported with streaming datasets."
    + )
1433-1442: Optional: avoid redundant checks (ruff SIM102-like)
Compute the “streaming enabled” condition once to avoid double-calling.
    - if self._is_streaming_enabled("train") or self._is_streaming_enabled("eval"):
    -     if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
    + streaming_enabled = self._is_streaming_enabled("train") or self._is_streaming_enabled("eval")
    + if streaming_enabled and os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
          raise ValueError("preprocess is not supported for streaming datasets")
1503-1548: Optional: use math.isclose for summation tolerance
Using math.isclose improves readability for the 1.0 sum check.
    - if abs(sum(weights) - 1.0) > 1e-6:
    -     raise ValueError(f"{weights_field} must sum to 1.0, got {sum(weights)}")
    + import math
    + total = sum(weights)
    + if not math.isclose(total, 1.0, rel_tol=0.0, abs_tol=1e-6):
    +     raise ValueError(f"{weights_field} must sum to 1.0, got {total}")
tests/test_datasets.py (1)
637-667: Strengthen assertion to verify both sources appear
len(set(sources)) >= 1 is always true. Check for both sources to ensure interleave correctness.
    - sources = [sample["source"] for sample in samples]
    - assert len(set(sources)) >= 1  # At least one unique source
    + sources = [sample["source"] for sample in samples]
    + # Expect samples from both streams
    + assert len(set(sources)) >= 2
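To see why the weaker assertion cannot catch a broken interleave, the following sketch compares both checks on toy data. It is an illustration only: `round_robin` and `make_stream` are hypothetical stand-ins written with the standard library, not the interleaving function the test actually exercises.

```python
from itertools import zip_longest
from typing import Iterable, Iterator


def round_robin(*streams: Iterable[dict]) -> Iterator[dict]:
    """Yield one item from each stream in turn until all are exhausted."""
    missing = object()
    for group in zip_longest(*streams, fillvalue=missing):
        for item in group:
            if item is not missing:
                yield item


def make_stream(source: str, n: int) -> Iterator[dict]:
    """Toy generator-backed stream tagged with its source name."""
    return ({"source": source, "idx": i} for i in range(n))


mixed_sources = [
    s["source"] for s in round_robin(make_stream("a", 3), make_stream("b", 3))
]
single_sources = [s["source"] for s in make_stream("a", 3)]

# The weak check passes even when only one stream contributes samples.
assert len(set(single_sources)) >= 1
# The stronger check passes only when both streams are represented.
assert len(set(mixed_sources)) >= 2
```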
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (12)
- src/axolotl/cli/args.py (0 hunks)
- src/axolotl/common/datasets.py (0 hunks)
- src/axolotl/datasets.py (4 hunks)
- src/axolotl/utils/data/sft.py (8 hunks)
- src/axolotl/utils/data/shared.py (4 hunks)
- src/axolotl/utils/data/utils.py (1 hunks)
- src/axolotl/utils/data/wrappers.py (0 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- src/axolotl/utils/schemas/training.py (1 hunks)
- src/axolotl/utils/schemas/validation.py (5 hunks)
- src/axolotl/utils/trainer.py (1 hunks)
- tests/test_datasets.py (4 hunks)
💤 Files with no reviewable changes (3)
- src/axolotl/common/datasets.py
- src/axolotl/utils/data/wrappers.py
- src/axolotl/cli/args.py
🧰 Additional context used
🧬 Code Graph Analysis (3)
tests/test_datasets.py (3)
- src/axolotl/utils/data/sft.py (4): _load_tokenized_prepared_datasets(292-351), prepare_datasets(90-107), _is_streaming_enabled_for_split(47-66), _get_streaming_config_for_split(69-86)
- src/axolotl/utils/dict.py (1): DictDefault(6-38)
- src/axolotl/utils/data/shared.py (1): _merge_datasets_with_strategy(576-611)
src/axolotl/utils/data/shared.py (1)
- src/axolotl/utils/dict.py (1): DictDefault(6-38)
src/axolotl/utils/data/sft.py (4)
- src/axolotl/utils/dict.py (1): DictDefault(6-38)
- src/axolotl/utils/data/utils.py (2): retry_on_request_exceptions(31-70), deduplicate_and_log_datasets(112-148)
- src/axolotl/utils/trainer.py (1): calculate_total_num_steps(392-519)
- src/axolotl/utils/data/shared.py (4): generate_dataset_hash_from_config(506-525), try_load_from_hub(486-503), load_preprocessed_dataset(455-483), save_preprocessed_dataset(407-452)
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/validation.py
539-541: Use a single if statement instead of nested if statements
(SIM102)
556-557: Use a single if statement instead of nested if statements
Combine if statements using and
(SIM102)
1437-1438: Use a single if statement instead of nested if statements
(SIM102)
🔇 Additional comments (31)
src/axolotl/utils/trainer.py (1)
550-553: Typo fix acknowledged; comment is now accurate.
The note clarifies the correct timing of distributed state init relative to ACCELERATE_USE_DEEPSPEED. No functional changes introduced.
src/axolotl/datasets.py (1)
1-11: Docstring and typing import look good.
Clearer module overview and Any import are appropriate for the added map/filter kwargs.
src/axolotl/utils/data/shared.py (2)
551-556: Pre-merge per-dataset shuffle is guarded correctly.
Only shuffles when all inputs are regular Datasets, avoiding IterableDataset pitfalls. Good.
557-573: Post-merge shuffle logic is sensible; good warning for curriculum_sampling.
Skips shuffle for streaming datasets and warns when curriculum sampling is on. Works as expected.
src/axolotl/utils/schemas/validation.py (8)
6-6: Import looks correct
Importing os is required for the new env-var checks below.
196-196: Pylint suppression is reasonable
Given the aggregation of many validators in this mixin, suppressing too-many-public-methods here is acceptable.
513-515: Nit: tighten the error message
The error message says “cannot be used together.” That’s fine, just noting the change is cosmetic. No action needed.
1388-1410: Nice addition: central streaming enablement logic
The per-context _is_streaming_enabled helper provides a clear source of truth and will reduce drift across call sites.
1411-1419: Requiring max_steps with streaming is correct
Trainer can’t infer length on streams; this validator is essential.
1444-1456: Auto-enforcing skip_prepare_dataset=True under streaming is good
The warning for explicit False is helpful; the auto-flip prevents expensive/invalid ops.
1458-1501: Mixing strategy/weights validation is solid
Covers invalid strategies, non-negative numeric weights, sum-to-1, and length matching. Good.
1561-1561: Good to include StreamingValidationMixin in ValidationMixin
Ensures streaming validations run by default across config usage.
src/axolotl/utils/data/sft.py (12)
12-12: Import IterableDatasetDict is necessary
You correctly handle both DatasetDict and IterableDatasetDict downstream.
47-67: Per-split streaming toggle looks correct
The precedence of eval_streaming (for test split), then streaming, then pretraining-default is sensible.
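The precedence described in this comment can be written out as a small decision function. This is a sketch under stated assumptions, not the helper from sft.py: the function name and parameters are hypothetical, None is assumed to mean "not set", and the pretraining default is assumed to be "stream when a pretraining dataset is configured".

```python
from typing import Optional


def is_streaming_enabled_for_split(
    split: str,
    streaming: Optional[bool],
    eval_streaming: Optional[bool],
    has_pretraining_dataset: bool,
) -> bool:
    """Resolve streaming for one split (hypothetical sketch of the precedence)."""
    # 1. An explicit eval_streaming setting wins for the test split.
    if split == "test" and eval_streaming is not None:
        return eval_streaming
    # 2. Otherwise fall back to the global streaming flag when it is set.
    if streaming is not None:
        return streaming
    # 3. Assumed default: stream only when a pretraining dataset is configured.
    return has_pretraining_dataset
```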
69-87: Eval-specific mixing overrides are handled cleanly
Using a shallow DictDefault copy to override only mixing fields is appropriate here.
106-108: Early branch to pretraining path is fine
Removes the old preprocess_iterable surface and simplifies prepare_datasets.
156-166: Correct: derive total steps from max_steps for streaming train datasets
Relying on cfg.max_steps for IterableDataset avoids unreliable step estimates. This assumes validators ensure max_steps is present when streaming.
If we want to harden this path for unvalidated cfg (e.g., in unit tests), we could add a defensive check:
- If train_dataset is IterableDataset and cfg.max_steps is falsy, raise a descriptive error early.
Do you want a patch for that?
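A defensive check of the kind offered here might look like the sketch below. The function name `total_steps_for_streaming` is hypothetical, and treating any falsy or non-positive max_steps as "unset" is an assumption about how the config represents a missing value.

```python
from typing import Optional


def total_steps_for_streaming(max_steps: Optional[int]) -> int:
    """Return the step budget for a streaming dataset, failing early if unset."""
    # Assumption: None, 0, and negative values all mean "max_steps not configured".
    if not max_steps or max_steps <= 0:
        raise ValueError(
            "max_steps must be set to a positive integer when the training "
            "dataset is streaming, because its length cannot be computed."
        )
    return int(max_steps)
```

Calling this at the point where the train dataset is known to be iterable would surface the misconfiguration before any training step runs.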
310-351: Streaming vs. cached loading paths are well separated
The streaming path bypasses caching and relies on raw loader; non-streaming still benefits from hub/disk caching. Good design.
395-401: Skip persisting streaming datasets
Guarding save_preprocessed_dataset to only run for non-iterable datasets prevents invalid save attempts. Good catch.
420-420: Split selection supports IterableDatasetDict
Good addition to support dict-based iterable datasets.
465-472: Val split size parsing is clear
Covers both absolute (>1) and fractional (0-1] expressions.
482-487: Dedup correctly skipped for streaming datasets
This matches utils.handle_long_seq_in_dataset behavior for streams.
493-501: Consistent dedup behavior for eval split
Mirrors the train handling. Looks good.
506-527: Shard application works for both Dataset and IterableDataset
Good reuse via a shared helper, and logging shard selection is helpful.
tests/test_datasets.py (7)
27-28: OK to suppress too-many-public-methods for this test class
Keeps pylint quiet without hiding real issues.
50-67: Streaming fixture looks fine
Simple generator-based IterableDataset is appropriate for tests.
509-550: Good end-to-end streaming test
Covers type, step count, and basic sample structure. One note: since tests bypass schema validators, they rely on utils.handle_long_seq_in_dataset to gracefully no-op for streams (fixed above).
551-571: Strategy validation smoke test is reasonable
Exercising _merge_datasets_with_strategy across valid strategies is useful.
572-597: Round-robin mixing test looks good
Interleaving assertion is sensible.
597-637: Weighted mixing test is fine
Lightweight proportionality and representation checks make sense here.
684-706: Eval mixing overrides test is accurate
Confirms that eval-specific settings shadow the main ones for the test split.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
src/axolotl/utils/data/sft.py (2)
210-236: Potential KeyError extracting pretraining config
Accessing config["name"], config["skip"], config["data_files"] will KeyError when omitted. Use .get with defaults.
Apply this diff:
      return DictDefault(
          {
              "path": config["path"],
    -         "name": config["name"],
    -         "skip": config["skip"],
    +         "name": config.get("name"),
    +         "skip": config.get("skip", 0),
              "split": config.get("split", "train"),
    -         "data_files": config.get("data_files"),
    +         "data_files": config.get("data_files"),
              "type": config.get("type", "pretrain"),
          }
      )
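The effect of the suggested change can be checked in isolation. In this sketch a plain dict stands in for DictDefault, and `extract_pretraining_config` is a hypothetical name; only "path" is treated as required, which is an assumption taken from the diff above.

```python
from typing import Any


def extract_pretraining_config(config: dict[str, Any]) -> dict[str, Any]:
    """Pull pretraining dataset fields, tolerating omitted optional keys."""
    return {
        "path": config["path"],  # required: a missing path should still fail loudly
        "name": config.get("name"),
        "skip": config.get("skip", 0),
        "split": config.get("split", "train"),
        "data_files": config.get("data_files"),
        "type": config.get("type", "pretrain"),
    }


# A config that sets only the path no longer raises KeyError.
minimal = extract_pretraining_config({"path": "some/dataset"})
```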
252-267: Be defensive accessing dispatch_batches from accelerator_config
accelerator_config can be a dict or object; attribute access may fail if the key is absent. Use safe access to avoid AttributeError.
Apply this diff:
    - if (
    -     cfg.accelerator_config
    -     and cfg.accelerator_config.dispatch_batches
    -     and not is_local_main_process()
    - ):
    + accel = cfg.accelerator_config
    + dispatch_batches = (
    +     accel.get("dispatch_batches")
    +     if isinstance(accel, dict)
    +     else getattr(accel, "dispatch_batches", None)
    + ) if accel else None
    + if accel and dispatch_batches and not is_local_main_process():
          iter_dataset = _create_placeholder_dataset()
      else:
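The dict-or-object lookup in this suggestion can be factored into a helper and exercised with both shapes of input. This is a minimal sketch: `get_dispatch_batches` is a hypothetical name, and SimpleNamespace stands in for whatever config object the project actually uses.

```python
from types import SimpleNamespace
from typing import Any


def get_dispatch_batches(accelerator_config: Any) -> Any:
    """Read dispatch_batches from a dict-like or attribute-style config."""
    if accelerator_config is None:
        return None
    if isinstance(accelerator_config, dict):
        return accelerator_config.get("dispatch_batches")
    return getattr(accelerator_config, "dispatch_batches", None)


# Both config shapes resolve without raising, even when the field is absent.
from_dict = get_dispatch_batches({"dispatch_batches": True})
from_object = get_dispatch_batches(SimpleNamespace(dispatch_batches=False))
missing = get_dispatch_batches(SimpleNamespace())
```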
♻️ Duplicate comments (1)
src/axolotl/utils/data/sft.py (1)
415-423: Per-split streaming flag now passed to loader (addresses earlier review)
This fixes the previously reported bug where cfg.streaming was used indiscriminately for all splits. Thanks for addressing.
🧹 Nitpick comments (13)
src/axolotl/utils/schemas/validation.py (6)
519-531: Sensible default: auto-set num_epochs=1.0 when steps/epochs unset
Good call to set a sane default; it unblocks downstream logic in normalize_config that derives save/eval steps. Consider documenting this behavior in the user-facing config docs.
532-548: Flatten nested condition for clarity (SIM102)
You can simplify the nested if to a single condition.
Apply this diff:
    - if saves_per_epoch is not None:
    -     # Check if saves_per_epoch is set but num_epochs is unset
    -     if num_epochs is None:
    -         raise ValueError(
    -             "saves_per_epoch requires num_epochs to be set to calculate save "
    -             "intervals."
    -         )
    + # Check if saves_per_epoch is set but num_epochs is unset
    + if saves_per_epoch is not None and num_epochs is None:
    +     raise ValueError(
    +         "saves_per_epoch requires num_epochs to be set to calculate save "
    +         "intervals."
    +     )
549-564: Flatten nested condition for clarity (SIM102)
Same idea here; combine into a single if.
Apply this diff:
    - if evals_per_epoch is not None:
    -     if num_epochs is None:
    -         raise ValueError(
    -             "evals_per_epoch requires num_epochs to be set to calculate "
    -             "evaluation intervals."
    -         )
    + if evals_per_epoch is not None and num_epochs is None:
    +     raise ValueError(
    +         "evals_per_epoch requires num_epochs to be set to calculate "
    +         "evaluation intervals."
    +     )
1445-1455: Error message says “skipping” but raises; adjust message
The exception text implies a soft skip; it actually hard-fails. Reword to avoid confusion.
Apply this diff:
    - raise ValueError(
    -     "Validation splits not supported for streaming datasets, skipping"
    - )
    + raise ValueError(
    +     "Validation splits are not supported with streaming datasets. Please set `val_set_size: 0` or use `test_datasets`."
    + )
1457-1465: Preprocess guard via env var is fine, but consider cfg flag too
This correctly blocks preprocess runs with streaming. Optional future enhancement: also respect a config-level is_preprocess (if present) to avoid relying solely on environment.
1558-1560: Use isclose for weight sum tolerance
Floating point sums can be finicky. Use math.isclose for robustness.
Apply this diff:
    - if abs(sum(weights) - 1.0) > 1e-6:
    -     raise ValueError(f"{weights_field} must sum to 1.0, got {sum(weights)}")
    + import math
    + if not math.isclose(sum(weights), 1.0, rel_tol=1e-9, abs_tol=1e-6):
    +     raise ValueError(f"{weights_field} must sum to 1.0, got {sum(weights)}")
tests/e2e/test_streaming.py (3)
29-65: E2E config likely requires GPU; add skip-if-no-GPU to stabilize CI
flash_attention=True and optimizer="adamw_torch_fused" typically require a CUDA environment and supported libs. To avoid sporadic CI failures on CPU jobs, gate these tests on GPU availability.
Apply this diff to add torch import and a class-level skip:
    @@
    - import pytest
    + import pytest
    + import torch
    @@
    - class TestStreamingDatasets:
    + @pytest.mark.skipif(
    +     not torch.cuda.is_available(),
    +     reason="GPU required for flash-attn and fused optimizer in streaming E2E tests",
    + )
    + class TestStreamingDatasets:
Alternatively, drop flash_attention/fused optimizer in these smoke tests to allow CPU runs.
78-85: TensorBoard threshold is quite tight for smoke tests
A fixed <2.5 loss can be flaky across seeds/datasets. Consider relaxing or asserting presence of metrics rather than absolute value to reduce flakiness.
139-141: Eval cadence: eval_steps=3 with max_steps=3
This triggers eval only at end; OK. If you want to see intermediate eval as well, set eval_steps=1. Not a blocker.
src/axolotl/utils/data/sft.py (4)
69-86: Shallow-copied cfg for eval overrides: acceptable, but note mutability
DictDefault(cfg) creates a shallow copy; only top-level overrides are adjusted, which is fine for strategy/weights. If deeper eval-only overrides are added later, consider copy.deepcopy to avoid aliasing surprises.
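The aliasing risk mentioned here is easy to reproduce with plain dictionaries. In this sketch a built-in dict stands in for DictDefault, on the assumption (taken from the comment above) that copying it is shallow; the key names mirror the config fields discussed in this review.

```python
import copy

base_cfg = {
    "dataset_mixing_strategy": "round_robin",
    "mixing_weights": [0.5, 0.5],
}

# Shallow copy: top-level keys are independent, nested objects are shared.
shallow = dict(base_cfg)
shallow["dataset_mixing_strategy"] = "weighted"  # rebinding: base_cfg unaffected
shallow["mixing_weights"].append(0.0)  # in-place mutation: visible via base_cfg

# Deep copy: nested objects are duplicated, so mutations stay local.
deep = copy.deepcopy(base_cfg)
deep["mixing_weights"].append(1.0)

assert base_cfg["dataset_mixing_strategy"] == "round_robin"
assert base_cfg["mixing_weights"] == [0.5, 0.5, 0.0]
```

Replacing a top-level value, as the eval override does for strategy and weights, is safe with a shallow copy; only in-place edits of nested values leak back.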
148-155: Fix typo in comment
Minor typo in the comment makes it hard to read.
Apply this diff:
    - # Skip validation for streaming eval datasets since theWhat hy don't have a calculable length
    + # Skip validation for streaming eval datasets since they don't have a calculable length
298-304: Return type should include IterableDataset/IterableDatasetDict
When streaming, this function returns IterableDataset or IterableDatasetDict; reflect that in the type hint.
Apply this diff:
    - ) -> tuple[Dataset | DatasetDict, list[Prompter | None]]:
    + ) -> tuple[Dataset | IterableDataset | DatasetDict | IterableDatasetDict, list[Prompter | None]]:
360-367: Return type should include IterableDataset
The streaming case returns IterableDataset. Update the annotation accordingly.
Apply this diff:
    - ) -> tuple[Dataset, list[Prompter | None]]:
    + ) -> tuple[Dataset | IterableDataset, list[Prompter | None]]:
📜 Review details
📒 Files selected for processing (3)
- src/axolotl/utils/data/sft.py (9 hunks)
- src/axolotl/utils/schemas/validation.py (6 hunks)
- tests/e2e/test_streaming.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (2)
tests/e2e/test_streaming.py (3)
- src/axolotl/utils/config/__init__.py (2): normalize_config(97-237), validate_config(259-303)
- src/axolotl/utils/dict.py (1): DictDefault(6-38)
- tests/e2e/utils.py (2): check_model_output_exists(171-192), check_tensorboard(149-168)
src/axolotl/utils/data/sft.py (4)
- src/axolotl/utils/dict.py (1): DictDefault(6-38)
- src/axolotl/utils/data/utils.py (1): deduplicate_and_log_datasets(112-148)
- src/axolotl/utils/trainer.py (1): calculate_total_num_steps(392-519)
- src/axolotl/utils/data/shared.py (4): try_load_from_hub(486-503), load_preprocessed_dataset(455-483), save_preprocessed_dataset(407-452), load_dataset_with_config(94-149)
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/validation.py
539-541: Use a single if statement instead of nested if statements
(SIM102)
556-557: Use a single if statement instead of nested if statements
Combine if statements using and
(SIM102)
1461-1462: Use a single if statement instead of nested if statements
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: pre-commit
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: pre-commit
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: preview
🔇 Additional comments (14)
src/axolotl/utils/schemas/validation.py (5)
1130-1153: Accelerate split/dispatch defaults for streaming look correct
Ensuring split_batches/dispatch_batches=False when streaming (train or eval) is enabled avoids IterableDataset incompatibilities with batch splitting. Looks good as a defensive default.
1434-1443: Max-steps requirement for streaming is correct
Requiring max_steps when streaming prevents unknown-length dataset issues with schedulers/optimizers. Solid.
1471-1480: Force skip_prepare_dataset=True for streaming: good
This aligns behavior across train/eval streaming. The warning when user sets False is helpful.
1562-1566: Good: enforce weights length to match dataset count
This check will catch common misconfigs early (covered in tests/e2e). Nicely done.
1585-1586: Wiring StreamingValidationMixin into ValidationMixin: good integration
Exposing streaming validations in the public mixin ensures consistent config behavior across the app.
tests/e2e/test_streaming.py (2)
16-23: Param covers round_robin/weighted/random; verify runtime support for “random”
Validation accepts "random", but ensure the dataset merging logic actually implements it to avoid false positives in tests.
Use the script I provided in the validator comment to confirm "random" handling is implemented.
164-202: Validation error test is solid
This accurately exercises the new weights-length validator. Nice.
src/axolotl/utils/data/sft.py (7)
47-67: Split-aware streaming gate is correct and matches validator semantics
Per-split streaming with eval_streaming override is cleanly handled here. Good alignment with schema validations.
162-172: Compute total steps for IterableDataset via max_steps: correct
This matches streaming constraints (unknown length). Good guard.
318-331: Stream path uses eval overrides: nice separation
Passing the split-specific cfg into _load_raw_datasets ensures eval-only strategy/weights are respected.
401-407: Skip saving streaming datasets: correct
The isinstance(dataset, IterableDataset) guard avoids expensive/invalid save-to-disk for streaming. Good.
427-437: Handle IterableDatasetDict split selection: good
Accounting for both DatasetDict and IterableDatasetDict prevents runtime errors for hub datasets that materialize as iterable dicts under streaming.
471-497: Dedup skip for streaming explicitly logged: good DX
Clear log messaging avoids confusion when dedup is configured but not applicable to IterableDataset.
513-534: Sharding supports both Dataset and IterableDataset
Looks correct; both types implement shard in HF Datasets. Good reuse.
Actionable comments posted: 1
♻️ Duplicate comments (2)
src/axolotl/utils/schemas/validation.py (1)
1483-1526: Comprehensive dataset mixing validation - resolves past review concerns.
The validation correctly:
- Validates all mixing strategies including "random" (which was missing in the previous commit)
- Enforces weight requirements for "weighted" strategy
- Validates weight constraints (non-negative, sum to 1.0)
- Matches weight count with dataset count
- Warns about ignored weights for non-weighted strategies
This addresses the concern raised in the past review about the missing "random" strategy implementation.
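The checks listed above can be combined into one function for illustration. This is a hedged sketch, not the validator from validation.py: `validate_mixing` and `VALID_STRATEGIES` are hypothetical names, warnings are returned as strings rather than logged, and the sum check uses math.isclose with an absolute tolerance of 1e-6 as suggested elsewhere in this review.

```python
import math
from typing import Optional, Sequence

VALID_STRATEGIES = ("concatenate", "round_robin", "weighted", "random")


def validate_mixing(
    strategy: str, weights: Optional[Sequence[float]], num_datasets: int
) -> list[str]:
    """Raise ValueError on invalid mixing config; return non-fatal warnings."""
    warnings: list[str] = []
    if strategy not in VALID_STRATEGIES:
        raise ValueError(f"unknown mixing strategy: {strategy!r}")
    if strategy == "weighted":
        if weights is None:
            raise ValueError("mixing_weights is required for 'weighted'")
        if len(weights) != num_datasets:
            raise ValueError("mixing_weights length must match dataset count")
        if any(w < 0 for w in weights):
            raise ValueError("mixing_weights must be non-negative")
        total = sum(weights)
        if not math.isclose(total, 1.0, rel_tol=0.0, abs_tol=1e-6):
            raise ValueError(f"mixing_weights must sum to 1.0, got {total}")
    elif weights is not None:
        warnings.append("mixing_weights is ignored unless strategy is 'weighted'")
    return warnings
```

The tolerance matters in practice: ten weights of 0.1 do not sum to exactly 1.0 in floating point, yet they should be accepted.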
src/axolotl/utils/data/sft.py (1)
420-423: Critical: Use split-specific streaming configuration.
The current implementation correctly determines split-specific streaming but then passes the global cfg to load_dataset_with_config. This could cause issues with evaluation-specific mixing strategies.
Apply this fix to use the split-specific configuration:
      use_streaming_for_split = _is_streaming_enabled_for_split(cfg, split)
    + streaming_cfg = _get_streaming_config_for_split(cfg, split) if use_streaming_for_split else cfg
      dataset = load_dataset_with_config(
    -     dataset_config, cfg.hf_use_auth_token, use_streaming_for_split
    +     dataset_config, streaming_cfg.hf_use_auth_token, use_streaming_for_split
      )
Also note that this aligns with the duplicate comment from the past review about using per-split streaming configuration.
🧹 Nitpick comments (2)
src/axolotl/utils/schemas/validation.py (2)
520-564: Consider simplifying nested conditions.
The nested if statements could be simplified for better readability. However, the logic is correct and the implementation properly handles conflicts between max_steps/num_epochs and epoch-based intervals.
Apply these simplifications:
      def check_saves_per_epoch_conflicts(cls, data):
          """Ensure saves_per_epoch is compatible with training configuration."""
          saves_per_epoch = data.get("saves_per_epoch")
          num_epochs = data.get("num_epochs")
    -     if saves_per_epoch is not None:
    -         # Check if saves_per_epoch is set but num_epochs is unset
    -         if num_epochs is None:
    -             raise ValueError(
    -                 "saves_per_epoch requires num_epochs to be set to calculate save "
    -                 "intervals."
    -             )
    +     if saves_per_epoch is not None and num_epochs is None:
    +         raise ValueError(
    +             "saves_per_epoch requires num_epochs to be set to calculate save "
    +             "intervals."
    +         )
          return data

      def check_evals_per_epoch_conflicts(cls, data):
          """Ensure evals_per_epoch is compatible with training configuration."""
          evals_per_epoch = data.get("evals_per_epoch")
          num_epochs = data.get("num_epochs")
    -     if evals_per_epoch is not None:
    -         if num_epochs is None:
    -             raise ValueError(
    -                 "evals_per_epoch requires num_epochs to be set to calculate "
    -                 "evaluation intervals."
    -             )
    +     if evals_per_epoch is not None and num_epochs is None:
    +         raise ValueError(
    +             "evals_per_epoch requires num_epochs to be set to calculate "
    +             "evaluation intervals."
    +         )
          return data
1461-1462: Simplify nested condition check.
Combine the nested if statements for better readability.
    - if self._is_streaming_enabled("train") or self._is_streaming_enabled("eval"):
    -     if os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
    -         raise ValueError("preprocess is not supported for streaming datasets")
    + if (self._is_streaming_enabled("train") or self._is_streaming_enabled("eval")) \
    +         and os.environ.get("AXOLOTL_IS_PREPROCESS") == "1":
    +     raise ValueError("preprocess is not supported for streaming datasets")
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (14)
- src/axolotl/cli/args.py (0 hunks)
- src/axolotl/common/datasets.py (0 hunks)
- src/axolotl/datasets.py (3 hunks)
- src/axolotl/utils/data/sft.py (9 hunks)
- src/axolotl/utils/data/shared.py (4 hunks)
- src/axolotl/utils/data/utils.py (1 hunks)
- src/axolotl/utils/data/wrappers.py (0 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- src/axolotl/utils/schemas/training.py (1 hunks)
- src/axolotl/utils/schemas/validation.py (6 hunks)
- src/axolotl/utils/trainer.py (1 hunks)
- tests/e2e/test_streaming.py (1 hunks)
- tests/test_datasets.py (4 hunks)
- tests/test_packed_dataset.py (0 hunks)
💤 Files with no reviewable changes (4)
- src/axolotl/utils/data/wrappers.py
- tests/test_packed_dataset.py
- src/axolotl/cli/args.py
- src/axolotl/common/datasets.py
🚧 Files skipped from review as they are similar to previous changes (6)
- src/axolotl/utils/trainer.py
- src/axolotl/utils/schemas/training.py
- src/axolotl/utils/schemas/config.py
- src/axolotl/utils/data/utils.py
- tests/e2e/test_streaming.py
- tests/test_datasets.py
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/validation.py
539-541: Use a single if statement instead of nested if statements
(SIM102)
556-557: Use a single if statement instead of nested if statements
Combine if statements using and
(SIM102)
1461-1462: Use a single if statement instead of nested if statements
(SIM102)
🔇 Additional comments (19)
src/axolotl/utils/data/shared.py (5)
528-532: LGTM! Clean implementation of mixed dataset type handling.
The updated signature and return type now properly support both Dataset and IterableDataset types, aligning with the PR's streaming dataset improvements.
543-550: Smart optimization for single dataset cases.
Great job on the early return logic that avoids unnecessary shuffling for curriculum sampling, disabled shuffling, or streaming datasets. This is both efficient and correct.
551-556: Good conditional dataset shuffling logic.
The logic correctly identifies when pre-merge shuffling is appropriate (only for non-streaming datasets), preventing errors with IterableDataset instances that don't support shuffling.
559-572: Excellent post-merge shuffling with appropriate warnings.
The implementation correctly:
- Skips shuffling for IterableDataset instances (which don't support shuffle)
- Warns about potential issues when shuffling with curriculum sampling
- Provides clear debug logging for all branches
576-610: Well-structured dataset mixing implementation with proper validation.
The implementation correctly handles all mixing strategies including the new "random" strategy that was missing in the previous commit (as noted in past reviews). The equal probability calculation for random mixing (1.0 / len(datasets)) is correct.
The error messages are clear and helpful, properly guiding users when they misconfigure streaming with concatenation.
src/axolotl/datasets.py (1)
48-89: Clean adaptation for streaming dataset support.
The process method now correctly handles both regular and iterable datasets:
- Skips feature extraction and column removal for IterableDataset (since features aren't available upfront)
- Properly configures num_proc only for regular datasets (streaming doesn't support multiprocessing)
The implementation is well-structured and maintains backward compatibility.
src/axolotl/utils/schemas/validation.py (4)
1415-1432: Well-designed streaming detection logic.
The _is_streaming_enabled helper correctly handles the complex logic for determining streaming state:
- Respects explicit eval_streaming for evaluation context
- Falls back to main streaming setting
- Defaults to streaming for pretraining datasets when not explicitly configured
This implementation aligns perfectly with the streaming configuration hierarchy described in the PR.
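The three-level hierarchy listed above can be sketched roughly as follows. This is a minimal illustration, not the PR's actual helper: the name resolve_streaming, the plain-dict config, and the use of a pretraining_dataset key to detect pretraining are all assumptions made for the example.

```python
from typing import Any, Mapping


def resolve_streaming(cfg: Mapping[str, Any], context: str = "train") -> bool:
    """Hypothetical stand-in for the streaming-detection helper."""
    # 1. An explicit eval_streaming value wins in the evaluation context.
    if context == "eval" and cfg.get("eval_streaming") is not None:
        return bool(cfg["eval_streaming"])
    # 2. Otherwise fall back to the main streaming setting, if it is set.
    if cfg.get("streaming") is not None:
        return bool(cfg["streaming"])
    # 3. With nothing configured, stream only when a pretraining dataset
    #    is present (assumed key name: "pretraining_dataset").
    return bool(cfg.get("pretraining_dataset"))
```

With a config such as {streaming: false, eval_streaming: true}, this sketch returns False for the train context and True for the eval context, which is the per-split behavior the review expects.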
1433-1443: Essential validation for streaming datasets.
Correctly enforces the requirement that max_steps must be set for streaming datasets since the Trainer cannot infer dataset length.
1444-1455: Good guard against incompatible configuration.
Properly prevents validation splits with streaming datasets, which is a technical limitation. The error message clearly explains why this combination isn't supported.
1130-1153: Good consolidation of accelerator configuration.
The new check_streaming_split_batches_accelerate method properly extends the existing pretraining logic to also handle streaming datasets. The implementation correctly configures accelerator settings for both training and evaluation streaming scenarios.
src/axolotl/utils/data/sft.py (9)
47-67: Well-designed split-specific streaming configuration.
The _is_streaming_enabled_for_split function correctly implements the streaming hierarchy:
- Respects explicit eval_streaming for test splits
- Falls back to main streaming setting
- Defaults to streaming for pretraining datasets
This aligns perfectly with the validation logic and provides a clean abstraction.
69-87: Clean configuration override for evaluation splits.
The _get_streaming_config_for_split function properly creates split-specific configurations, allowing evaluation datasets to use different mixing strategies and weights than training datasets.
105-108: Clean routing to specialized dataset preparation.
The simplified dispatch logic correctly routes to the appropriate preparation function based on dataset type.
148-161: Good defensive programming for streaming eval datasets.
The code correctly skips sample packing validation for streaming eval datasets since they don't have a calculable length. The comment clearly explains why this check is necessary.
163-172: Proper step calculation for different dataset types.
The implementation correctly handles step counting:
- For IterableDataset: Uses cfg.max_steps directly (required for streaming)
- For regular datasets: Calculates based on dataset size with optional max_steps cap
This aligns with the validation requirement that max_steps must be set for streaming datasets.
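The two step-counting branches can be illustrated with a deliberately simplified sketch. The function estimate_total_steps and its parameters are hypothetical; the real calculation in the repository also has to account for details such as sample packing and gradient accumulation, which are omitted here.

```python
import math
from typing import Optional


def estimate_total_steps(
    num_examples: Optional[int],
    batch_size: int,
    num_epochs: int,
    max_steps: Optional[int] = None,
) -> int:
    """Hypothetical helper; num_examples=None models a streaming dataset."""
    if num_examples is None:
        # Streaming: the length is unknown, so max_steps is the only input.
        if max_steps is None:
            raise ValueError("max_steps must be set for streaming datasets")
        return max_steps
    # Map-style: derive the step count from the dataset size ...
    steps = math.ceil(num_examples / batch_size) * num_epochs
    # ... and treat max_steps as an optional cap.
    return min(steps, max_steps) if max_steps is not None else steps
```

The streaming branch is why the validator insists on max_steps: without a dataset length there is nothing else to derive the total from.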
318-356: Well-structured caching logic with streaming support.
The implementation correctly:
- Skips all caching (hub/disk) for streaming datasets
- Maintains backward compatibility with non-streaming caching flow
- Uses split-specific streaming configuration
This provides a clean separation between streaming and non-streaming paths.
401-407: Correct handling of streaming datasets.
The code properly avoids saving streaming datasets to disk (which isn't supported) while maintaining the save functionality for regular datasets.
471-497: Comprehensive handling of dataset splits with streaming support.
Both _handle_train_dataset_split and _handle_test_dataset_split functions correctly:
- Accept both Dataset and IterableDataset types
- Skip deduplication for streaming datasets with appropriate logging
- Maintain correct return type signatures
The implementation is thorough and handles all edge cases properly.
499-511: See comment for lines 471-497.
|
What did you find in your manual testing? |
@salmanmohammadi Basic tests worked with no noticeable extra data time. I should get benchmark numbers (esp for more complicated situations; e.g., multidataset, multiprocess) before merge, though. |
Actionable comments posted: 3
♻️ Duplicate comments (2)
src/axolotl/utils/data/shared.py (1)
576-610: Unify default strategy with schema; add a defensive check for weighted
- The function falls back to "concatenate" if the key is missing, but the schema default is "round_robin". Use the same default here to avoid surprises in call sites that pass a plain dict lacking the key.
- Add a minimal runtime check for weighted probabilities (length and sum) to fail fast even if validation is bypassed for some code paths (e.g., merging test_datasets with global mixing settings but no prior validation).
Apply:
-    strategy = cfg.get("dataset_mixing_strategy", "concatenate")
+    strategy = cfg.get("dataset_mixing_strategy", "round_robin")
     weights = cfg.get("mixing_weights", None)
@@
-    if strategy == "weighted":
-        return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
+    if strategy == "weighted":
+        if weights is None:
+            raise ValueError("mixing_weights must be provided when strategy='weighted'")
+        if len(weights) != len(datasets):
+            raise ValueError(
+                f"mixing_weights length ({len(weights)}) must match number of datasets ({len(datasets)})"
+            )
+        if not (abs(sum(weights) - 1.0) <= 1e-6):
+            raise ValueError(f"mixing_weights must sum to 1.0, got {sum(weights)}")
+        return interleave_datasets(datasets, probabilities=weights, seed=cfg.seed)
Also, nice job adding the "random" branch; this addresses prior feedback about config/implementation parity.
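For reference, the weight checks suggested above can be pulled into a small standalone function that is easy to unit-test without any dataset objects. This is only a sketch: validate_mixing_weights is a hypothetical name, and the 1e-6 tolerance simply mirrors the value used in the suggestion.

```python
import math
from typing import Optional, Sequence


def validate_mixing_weights(
    weights: Optional[Sequence[float]], num_datasets: int
) -> None:
    """Hypothetical fail-fast check for the 'weighted' mixing strategy."""
    if weights is None:
        raise ValueError("mixing_weights must be provided when strategy='weighted'")
    if len(weights) != num_datasets:
        raise ValueError(
            f"mixing_weights length ({len(weights)}) must match "
            f"number of datasets ({num_datasets})"
        )
    # Compare with a tolerance so float rounding (e.g. 0.1 + 0.2 + 0.7) passes.
    if not math.isclose(sum(weights), 1.0, abs_tol=1e-6):
        raise ValueError(f"mixing_weights must sum to 1.0, got {sum(weights)}")
```

Calling it right before the interleaving step keeps the merge path safe even when config validation was bypassed.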
src/axolotl/utils/data/sft.py (1)
288-301: Critical: per-split streaming decision ignored for 'test'; eval streaming cannot be enabled
Both _load_tokenized_prepared_datasets and _load_and_process_single_dataset only compute use_streaming for the "train" split, so evaluation never streams. This regresses scenarios like streaming=False with eval_streaming=True and matches the error reported in the PR discussion.
Apply this diff to respect per-split streaming in both places:
@@
-    use_streaming = False
-    if split == "train":
-        use_streaming = _is_streaming_enabled(cfg)
+    use_streaming = _is_streaming_enabled_for_split(cfg, split)
@@
-    use_streaming = False
-    if split == "train":
-        use_streaming = _is_streaming_enabled(cfg)
+    use_streaming = _is_streaming_enabled_for_split(cfg, split)
I can also add a small unit/e2e test covering {streaming: false, eval_streaming: true} to prevent regressions. Want me to push that?
Also applies to: 391-396
🧹 Nitpick comments (9)
src/axolotl/utils/schemas/config.py (1)
935-952: Type-safety for mixing strategy; align default with merge layer
Great addition. Two improvements:
- Constrain dataset_mixing_strategy at the type level to prevent invalid values earlier (you already validate later; this complements it).
- Ensure the default here matches the merge function's fallback (shared._merge_datasets_with_strategy currently defaults to "concatenate" if key missing).
Apply:
-    dataset_mixing_strategy: str | None = Field(
-        default="round_robin",
+    dataset_mixing_strategy: Literal["concatenate", "round_robin", "weighted", "random"] = Field(
+        default="round_robin",
         json_schema_extra={
             "description": "Strategy for mixing multiple datasets: 'concatenate', 'round_robin' (equal sampling), 'weighted' (use mixing_weights), or 'random' (random sampling with equal probability). Works for both streaming and non-streaming datasets."
         },
     )
Optionally, if you want to tighten weights:
-    mixing_weights: list[float] | None = Field(
+    mixing_weights: list[float] | None = Field(
         default=None,
(Weights length and sum are already validated in StreamingValidationMixin, so this is optional.)
src/axolotl/utils/schemas/validation.py (3)
519-531: Auto-default num_epochs to 1: prefer int over float
Good guard. Minor nit: defaulting to an int avoids downstream surprises where integer semantics are assumed.
Apply:
-    if max_steps is None and num_epochs is None:
-        data["num_epochs"] = 1.0
+    if max_steps is None and num_epochs is None:
+        data["num_epochs"] = 1
1130-1150: Streaming accelerate defaults: DRY with pretraining path
This mirrors check_pretraining_split_batches_accelerate. Consider factoring shared logic into a small helper to avoid divergence over time.
1409-1529: Streaming validations are comprehensive; tweak messaging and centralize constants
- The “Validation splits not supported …, skipping” message in check_streaming_validation_splits_conflict raises an error; remove “skipping” to avoid confusion.
- valid_strategies is hard-coded here and in data/shared.py. Consider centralizing the set (e.g., an Enum or module-level constant) to prevent drift.
Apply:
-    raise ValueError(
-        "Validation splits not supported for streaming datasets, skipping"
-    )
+    raise ValueError(
+        "Validation splits are not supported for streaming datasets."
+    )
If you'd like, I can draft a small constants module (e.g., src/axolotl/utils/data/mixing_strategies.py) and wire both the validator and merger to import from it.
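A constants module along those lines might look like the sketch below. Everything in it is an assumption for illustration (the names MixingStrategy, VALID_MIXING_STRATEGIES, DEFAULT_MIXING_STRATEGY, and check_mixing_strategy do not come from the PR); the point is that one Literal type can feed the schema annotation, the validator, and the merge fallback.

```python
from typing import Literal, get_args

# Single source of truth for the allowed strategy names (hypothetical module).
MixingStrategy = Literal["concatenate", "round_robin", "weighted", "random"]
VALID_MIXING_STRATEGIES = frozenset(get_args(MixingStrategy))
DEFAULT_MIXING_STRATEGY: MixingStrategy = "round_robin"


def check_mixing_strategy(strategy: str) -> str:
    """Raise early on an unknown strategy name; return it unchanged otherwise."""
    if strategy not in VALID_MIXING_STRATEGIES:
        raise ValueError(
            f"Unknown dataset_mixing_strategy {strategy!r}; "
            f"expected one of {sorted(VALID_MIXING_STRATEGIES)}"
        )
    return strategy
```

Both the schema default and the merge function's cfg.get(...) fallback could then reference DEFAULT_MIXING_STRATEGY, which removes the "concatenate" versus "round_robin" mismatch noted earlier.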
tests/e2e/test_streaming.py (1)
31-67: Make FA usage and external dependencies more CI-stable
Two small tweaks to reduce CI flakiness:
- flash_attention=True may fail on CPU-only runners; consider using a config knob like flash_attention: False (or leave it unset) for these smoke tests.
- The tests pull public datasets/models; if rate-limited, runs can fail. If available, consider pinning to tiny local fixtures or HF “dummy” datasets for CI.
src/axolotl/utils/data/sft.py (4)
118-124: Clarify/guard: sample_packing with streaming eval
You skip the "too small" step-count check for IterableDataset evals, but if packing is enabled on a streaming eval dataset, is it actually supported? If not, fail fast with a clear error; if yes, document expected behavior.
Optionally add a guard:
     if (
         eval_dataset
         and cfg.sample_packing
         and cfg.eval_sample_packing is not False
         and not isinstance(eval_dataset, IterableDataset)
     ):
         total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
         if total_eval_steps == 0:
             raise ValueError(
                 "eval dataset split is too small for sample_packing. "
                 "You should set `eval_sample_packing: False` in your config."
             )
+
+    # If eval is streaming, either disable packing or explicitly support it.
+    if (
+        isinstance(eval_dataset, IterableDataset)
+        and cfg.sample_packing
+        and cfg.eval_sample_packing is not False
+    ):
+        raise ValueError(
+            "Evaluation with streaming datasets does not support sample_packing. "
+            "Set `eval_sample_packing: False` or disable `sample_packing`."
+        )
If packing on streaming evals is supported in process_datasets_for_packing, ignore this and add a brief comment here stating so.
273-274: Type hint: include IterableDataset in return type
_load_tokenized_prepared_datasets can return an IterableDataset (streaming path), but the annotation excludes it.
Apply:
-) -> tuple[Dataset | DatasetDict, list[Prompter | None]]:
+) -> tuple[Dataset | IterableDataset | DatasetDict, list[Prompter | None]]:
334-337: Type hint: _load_raw_datasets can return IterableDataset
The merged dataset may be streaming; widen the annotation.
-) -> tuple[Dataset, list[Prompter | None]]:
+) -> tuple[Dataset | IterableDataset, list[Prompter | None]]:
445-452: Good: skip exact deduplication for streaming datasets
Dedup on IterableDataset isn't supported; the explicit skip and log message are appropriate. Consider consolidating the two log messages into a small helper if this pattern appears elsewhere.
Also applies to: 462-467, 472-481
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
- src/axolotl/cli/args.py (0 hunks)
- src/axolotl/common/datasets.py (0 hunks)
- src/axolotl/datasets.py (4 hunks)
- src/axolotl/utils/data/sft.py (9 hunks)
- src/axolotl/utils/data/shared.py (4 hunks)
- src/axolotl/utils/data/utils.py (1 hunks)
- src/axolotl/utils/schemas/config.py (1 hunks)
- src/axolotl/utils/schemas/training.py (1 hunks)
- src/axolotl/utils/schemas/validation.py (6 hunks)
- src/axolotl/utils/trainer.py (2 hunks)
- tests/e2e/integrations/test_kd.py (1 hunks)
- tests/e2e/test_streaming.py (1 hunks)
- tests/test_datasets.py (4 hunks)
- tests/test_packed_dataset.py (0 hunks)
💤 Files with no reviewable changes (3)
- src/axolotl/cli/args.py
- src/axolotl/common/datasets.py
- tests/test_packed_dataset.py
🚧 Files skipped from review as they are similar to previous changes (5)
- src/axolotl/utils/trainer.py
- src/axolotl/utils/schemas/training.py
- src/axolotl/utils/data/utils.py
- tests/test_datasets.py
- src/axolotl/datasets.py
🧰 Additional context used
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/validation.py
539-541: Use a single if statement instead of nested if statements
(SIM102)
556-557: Use a single if statement instead of nested if statements
Combine if statements using and
(SIM102)
1453-1454: Use a single if statement instead of nested if statements
(SIM102)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
🔇 Additional comments (15)
src/axolotl/utils/data/shared.py (3)
528-549: Single-dataset path: shuffle guard is correct
The early-return logic for a single dataset properly avoids shuffling for streaming or curriculum cases; otherwise it shuffles deterministically. LGTM.
551-556: Pre-merge per-dataset shuffle is reasonable
Conditioning the pre-merge shuffle on all-Arrow Datasets avoids materializing streaming sources. LGTM.
557-573: Shuffle-after-merge behavior is sensible with clear logging
Skipping shuffle for IterableDataset and warning on curriculum+shuffle is good UX. LGTM.
src/axolotl/utils/schemas/validation.py (4)
6-6: Import addition is fine
os is required for the streaming preprocess guard. Nothing to change.
196-196: Pylint suppression is acceptable here
Given the number of validations, suppressing too-many-public-methods is reasonable.
532-548: saves_per_epoch requires num_epochs: good validation
This prevents runtime division/interval ambiguities. LGTM.
549-564: evals_per_epoch requires num_epochs: good validation
Consistent with saves_per_epoch handling. LGTM.
tests/e2e/test_streaming.py (4)
18-25: Solid coverage for round_robin, weighted, and random
Nice parameterization to hit all three strategies in one test.
88-126: Validation error test is precise
Good negative test asserting both mismatched length and the explanatory phrase. LGTM.
127-171: Three-dataset weighted case expands surface area nicely
Covers non-binary weighting with correct normalization. LGTM.
81-86: Manual Verification Required: Confirm TensorBoard Scalar Tags
I wasn't able to locate any TensorBoard event files in ${TMPDIR} to automatically inspect the emitted tags. Please run your streaming e2e test locally, point it at a writable log directory, and confirm which scalar keys are emitted. For example:
#!/bin/bash
set -euo pipefail
TEST_LOGDIR=/tmp/tb_logs
pytest tests/e2e/test_streaming.py --capture=no --tb=short -- --logdir="$TEST_LOGDIR"
# Pick the first event file and pass it to Python as an argument
EVENT_FILE=$(fd -t f -g 'events.out.tfevents*' "$TEST_LOGDIR" | head -n1)
python - "$EVENT_FILE" <<'PY'
import sys
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
ea = EventAccumulator(sys.argv[1]); ea.Reload()
# Tags()["scalars"] lists the scalar tag names; Scalars() needs a specific tag
print(sorted(ea.Tags()["scalars"]))
PY
Then verify:
- Are both "train/train_loss" and "train/loss" present?
- If yes, no further action is needed.
- If only one appears, please align this test's expected tag (at lines 81-86 in tests/e2e/test_streaming.py) to use "train/loss", matching the other e2e tests, to avoid false negatives.
src/axolotl/utils/data/sft.py (4)
371-377: Good: avoid saving streaming datasets to disk
Skipping persistence for IterableDataset prevents subtle runtime surprises and unnecessary I/O.
400-405: Good: handle both DatasetDict and IterableDatasetDict
Handling both dict types here avoids split selection bugs when sources expose dicts in streaming mode.
12-12: Import: IterableDatasetDict addition looks right
This enables split selection for streaming dataset dicts; aligns with the downstream instance checks.
486-489: Verify shard semantics on streaming vs. map-style datasets
Ensure that calling dataset.shard(...) on an IterableDataset (which returns a contiguous segment) matches your intended splitting strategy, as map-style Dataset.shard(...) applies strided (every-Nth sample) selection. This distinction can affect data balance and mixing when shards are later combined (e.g., with weighted or round-robin strategies).
- Review _apply_dataset_sharding in src/axolotl/utils/data/sft.py (definition at lines 486-494; shard call at lines 502-504) to confirm contiguous-segment sharding is acceptable for iterable streams.
- Review the initial sharding after dataset.shuffle(seed=…) in the same file (lines 412-416) to ensure strided sampling on map-style datasets yields the desired distribution.
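To make the strided versus contiguous distinction concrete, here is a library-free sketch on a plain list. The helper names are hypothetical, and which mode a particular shard(...) call actually uses depends on the dataset type, library version, and arguments, so treat this purely as an illustration of the two selection schemes.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def strided_shard(items: Sequence[T], num_shards: int, index: int) -> List[T]:
    """Every num_shards-th item, starting at position index."""
    return list(items[index::num_shards])


def contiguous_shard(items: Sequence[T], num_shards: int, index: int) -> List[T]:
    """One unbroken slice per shard; earlier shards absorb any remainder."""
    div, mod = divmod(len(items), num_shards)
    start = div * index + min(index, mod)
    end = start + div + (1 if index < mod else 0)
    return list(items[start:end])
```

On an unshuffled source that is ordered by topic or difficulty, the strided shard samples the whole range while the contiguous shard sees only one region, which is why the choice matters for later mixing.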
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
src/axolotl/datasets.py (1)
92-103: Potential crash for streaming datasets: dataset.features can be None on IterableDataset; also inconsistent with process() behavior.
IterableDataset often has features=None (e.g., JSONL streams). Calling list(dataset.features.keys()) will raise. In process(), you already skip remove_columns for IterableDataset; wrap_dataset_for_tokenized_prompt should be consistent and handle features safely.
Apply this diff to safely compute remove_columns, add a desc, and reuse the tokenizer's batch_size when batched:
@@
-    if isinstance(dataset, IterableDataset):
-        map_kwargs = {}
-        if prompt_tokenizer.supports_batched:
-            map_kwargs["batched"] = True
-
-        # Map the dataset and remove original columns
-        return dataset.map(
-            prompt_tokenizer.tokenize_prompt,
-            remove_columns=list(dataset.features.keys()),
-            **map_kwargs,
-        )
+    if isinstance(dataset, IterableDataset):
+        map_kwargs = {}
+        if prompt_tokenizer.supports_batched:
+            map_kwargs["batched"] = True
+            map_kwargs["batch_size"] = getattr(prompt_tokenizer, "batch_size", 1_000)
+        map_kwargs["desc"] = "Tokenizing Prompts"
+
+        # Determine removable columns safely; features may be None for streaming sources
+        features = getattr(dataset, "features", None)
+        remove_columns = list(features.keys()) if features is not None else None
+
+        # Map the dataset; if remove_columns is None, leave originals and rely on downstream collators
+        return dataset.map(
+            prompt_tokenizer.tokenize_prompt,
+            remove_columns=remove_columns,
+            **map_kwargs,
+        )
Follow-up: If leaving originals causes model forward errors (as previously noted in review), you may need a downstream batch filter that whitelists only model-accepted keys when features=None. I can provide a small utility for that if helpful.
src/axolotl/utils/trainer.py (2)
371-375: Eval IterableDataset still uses .filter(...); align with train path to avoid streaming breakage.
This likely explains the reported eval streaming error. Mirror the train path and use the custom helper with batched=True.
Apply:
-    eval_dataset = eval_dataset.filter(
-        drop_no_trainable_tokens,
-        **filter_map_kwargs,
-        **drop_long_kwargs,
-    )
+    if isinstance(eval_dataset, IterableDataset):
+        eval_dataset = _create_filtered_iterable_dataset(
+            eval_dataset, drop_no_trainable_tokens, batched=True
+        )
+    else:
+        eval_dataset = eval_dataset.filter(
+            drop_no_trainable_tokens,
+            batched=True,
+            **filter_map_kwargs,
+            **drop_long_kwargs,
+        )
Note: filter_map_kwargs was computed from the train dataset type; if train is map-style but eval is iterable (or vice versa), consider computing kwargs per-dataset to avoid passing num_proc to iterable filters.
319-329: Critical bug in label filtering: list-vs-int comparison prevents dropping zero-token samples
The use of np.any(labels != -100) and np.any(row_labels != -100) on plain Python lists always returns a single boolean True (because labels != -100 is a scalar list-vs-int comparison), so rows consisting entirely of -100 are never filtered out. We confirmed these are the only occurrences in the repo (both at src/axolotl/utils/trainer.py:323 and …:327).
Please apply the following patch:
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -320,7 +320,8 @@ def should_keep_sample(labels):
     # If first element is an int, we assume a single example
     # If it's a list, we assume we're dealing with a batch
-    if isinstance(labels[0], int):
-        # Single example: return a single bool
-        return np.any(labels != -100)
+    if isinstance(labels[0], int):
+        # Single example: return a single bool (drop if all -100)
+        return any(label != -100 for label in labels)
     # Batched: 'labels' is a list of lists
     # Return a list of booleans, one per sub-list
-    results = [np.any(row_labels != -100) for row_labels in labels]
-    return results
+    results = [
+        any(label != -100 for label in row_labels)  # per-example drop logic
+        for row_labels in labels
+    ]
+    return results
- Verified only these two calls of np.any(… != -100) exist in the codebase.
- Switching to Python's built-in any() ensures correct element-wise truth evaluation on lists.
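The pitfall is easy to reproduce without NumPy or the trainer. The sketch below restates the patched function in standalone form (the IGNORE_INDEX constant is a name introduced here for readability) and shows why the original comparison could never drop a fully masked row.

```python
IGNORE_INDEX = -100  # label value that marks non-trainable tokens


def should_keep_sample(labels):
    """Keep a sample only if at least one label is trainable."""
    if isinstance(labels[0], int):
        # Single example: one bool.
        return any(label != IGNORE_INDEX for label in labels)
    # Batched: one bool per row.
    return [any(label != IGNORE_INDEX for label in row) for row in labels]


all_masked = [-100, -100, -100]

# Comparing a list to an int does not broadcast: the result is one bool,
# and since a list never equals an int, `!=` is always True.
list_vs_int = all_masked != IGNORE_INDEX
```

Here list_vs_int is True even though every label is masked, so wrapping it in np.any(...) cannot help; the element-wise generator inside any() is what restores the intended check.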
🧹 Nitpick comments (5)
src/axolotl/datasets.py (4)
4-6: Fix minor grammar in the docstring.
Use the contraction "Let's".
Apply this diff:
- We want this to be a wrapper for an existing dataset that we have loaded. Lets use the
+ We want this to be a wrapper for an existing dataset that we have loaded. Let's use the
46-46: Narrow the return type or document the expectation.
process() advertises Dataset | IterableDataset, but __init__ immediately accesses .data, which exists only for Dataset. Either narrow the return type to Dataset for this path or document that TokenizedPromptDataset must not be constructed with an IterableDataset.
48-51: Guard against Features.keys() returning a view; pass a concrete list to remove_columns.
Hugging Face APIs expect str or list[str] for remove_columns. Using a keys view can be brittle.
Apply this diff:
-    features = None
-    if not isinstance(dataset, IterableDataset):
-        features = dataset.features.keys()
+    features = None
+    if not isinstance(dataset, IterableDataset):
+        features = list(dataset.features.keys())
Also applies to: 75-80
53-56: Make batch size configurable (avoid hard-coded 1_000).
Different tokenizers/sequences need different batching to avoid OOM or underutilization. Read an optional batch_size from the tokenizer, with a sane default.
Apply this diff:
-    if self.prompt_tokenizer.supports_batched:
-        map_kwargs["batched"] = True
-        map_kwargs["batch_size"] = 1_000
+    if self.prompt_tokenizer.supports_batched:
+        map_kwargs["batched"] = True
+        map_kwargs["batch_size"] = getattr(self.prompt_tokenizer, "batch_size", 1_000)
src/axolotl/utils/trainer.py (1)
621-633: Avoid initializing NCCL for single-process CPU/MPS runs.
Gate the deepspeed comm init on CUDA availability to prevent crashes on non-CUDA setups.
Apply:
-    if (
-        int(os.environ.get("WORLD_SIZE", "1")) == 1
-        and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
-    ):
+    if (
+        int(os.environ.get("WORLD_SIZE", "1")) == 1
+        and os.environ.get("AXOLOTL_IS_PREPROCESS", "0") != "1"
+        and torch.cuda.is_available()
+    ):
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/axolotl/datasets.py (4 hunks)
- src/axolotl/utils/trainer.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/utils/trainer.py (1)
src/axolotl/core/builders/base.py (2)
- train_dataset (81-82)
- train_dataset (85-86)
src/axolotl/datasets.py (2)
src/axolotl/prompt_strategies/chat_template.py (1)
- tokenize_prompt (391-416)
src/axolotl/prompt_tokenizers.py (3)
- tokenize_prompt (66-67)
- tokenize_prompt (118-144)
- tokenize_prompt (259-288)
🪛 Ruff (0.12.2)
src/axolotl/utils/trainer.py
43-43: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
59-59: Use key in dict instead of key in dict.keys()
Remove .keys()
(SIM118)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: pre-commit
- GitHub Check: pre-commit
- GitHub Check: preview
🔇 Additional comments (3)
src/axolotl/datasets.py (1)
1-7: Docstring cleanup looks good.
Thanks for removing references to ConstantLengthDataset and clarifying the middleware-style approach and padding via collators.
src/axolotl/utils/trainer.py (2)
344-357: LGTM: Custom filter path for train IterableDatasets avoids features/caching pitfalls.
This is the right direction for streaming. With the mask-fix above, it should behave correctly for batched filtering.
542-545: Confirm no collation issues when including the length column in the temporary DataLoader.
Since this DataLoader is only used to compute len(data_loader), it won't iterate examples; still, confirm no side effects (e.g., accidental reuse elsewhere where a collator might choke on length). If issues surface, swap back to a .remove_columns(["length"]) view specifically for this estimator path.
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/axolotl/datasets.py (1)
97-106: Risk: not removing source columns for streaming can break model.forward.
When features is None (common for streaming), remove_columns stays None; original columns persist and often get forwarded to model(**batch), causing unexpected-arg errors. This matches a prior maintainer note about needing to remove extra columns to avoid forward errors.
Two robust options:
- Option A (fail fast, preferred): require features for streaming datasets so remove_columns is known and safe.
- Option B (soft fallback): warn loudly when features is None and encourage providing features at load_dataset(..., streaming=True, features=...); still proceed but make the risk explicit.
Minimal changes for Option A:
-    remove_columns = None
-    if dataset.features is not None:
-        remove_columns = list(dataset.features.keys())
+    remove_columns: list[str] | None = None
+    if dataset.features is not None:
+        remove_columns = list(dataset.features.keys())
+    else:
+        # Failing fast avoids silent training/runtime errors later.
+        raise ValueError(
+            "IterableDataset.features is None; cannot determine remove_columns. "
+            "Please create the streaming dataset with explicit `features=...` so "
+            "we can drop source columns after tokenization."
+        )
If you're not ready to hard-fail, consider this soft fallback:
-    remove_columns = None
-    if dataset.features is not None:
-        remove_columns = list(dataset.features.keys())
+    remove_columns = (
+        list(dataset.features.keys()) if dataset.features is not None else None
+    )
+    if remove_columns is None:
+        LOG.warning(
+            "IterableDataset.features is None; proceeding without remove_columns. "
+            "Upstream code MUST ensure only model-expected keys are forwarded, or "
+            "provide `features` when building the streaming dataset."
+        )
Recommendation: also update dataset loading sites to pass explicit Features for streaming datasets so remove_columns can be computed deterministically.
🧹 Nitpick comments (8)
src/axolotl/datasets.py (8)
1-7: Docstring nit: fix contraction and tighten wording.Minor grammar and clarity improvements. Also keeps terminology consistent with “middleware-style wrappers”.
- We want this to be a wrapper for an existing dataset that we have loaded. Lets use the + We want this to be a wrapper for an existing dataset that we have loaded. Let's use the concept of middlewares to wrap each dataset. We'll use the collators later on to pad the datasets.
20-28: Class docstring is misleading for non-iterable path.
This class is only constructed for non-Iterable datasets (Iterable path returns a mapped dataset directly). The docstring saying “stream of text files” can confuse readers.
-"""Dataset that returns tokenized prompts from a stream of text files.
+"""Arrow-backed dataset wrapper that returns tokenized prompts.
+
+Note: Streaming (IterableDataset) is handled in wrap_dataset_for_tokenized_prompt
+and does not instantiate this class.
46-51: Prefer column_names and ensure a concrete list for remove_columns.
Using dataset.column_names is the canonical, library-supported way and returning a list avoids surprises when passed into remove_columns.
- features = None
- if not isinstance(dataset, IterableDataset):
-     features = dataset.features.keys()
+ features: list[str] | None = None
+ if not isinstance(dataset, IterableDataset):
+     # Stable, explicit list of original columns for remove_columns
+     features = list(dataset.column_names)
61-68: Guard filter_rows with callable check and avoid truthiness pitfalls.
If filter_rows exists but isn’t callable (e.g., set to True/False), dataset.filter will error. Use callable() and bind the name once.
- if (
-     hasattr(self.prompt_tokenizer, "filter_rows")
-     and self.prompt_tokenizer.filter_rows
- ):
-     filter_kwargs: dict[str, Any] = {"desc": "Strategy Filtering Rows"}
+ filter_fn = getattr(self.prompt_tokenizer, "filter_rows", None)
+ if callable(filter_fn):
+     filter_kwargs: dict[str, Any] = {"desc": "Strategy Filtering Rows"}
      if not isinstance(dataset, IterableDataset):
          filter_kwargs["num_proc"] = self.process_count
      dataset = dataset.filter(
-         self.prompt_tokenizer.filter_rows,
+         filter_fn,
          **filter_kwargs,
      )
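To illustrate why the `callable()` guard matters, here is a small standalone sketch. The three strategy classes and `get_row_filter` are hypothetical stand-ins for a prompt tokenizer, not classes from the repository.

```python
from typing import Any, Callable, Optional


def get_row_filter(prompt_tokenizer: Any) -> Optional[Callable[..., Any]]:
    """Return `filter_rows` only when it is actually callable."""
    filter_fn = getattr(prompt_tokenizer, "filter_rows", None)
    return filter_fn if callable(filter_fn) else None


class FlagStrategy:
    # Hypothetical strategy that misuses filter_rows as a boolean flag.
    filter_rows = True


class FilteringStrategy:
    # Hypothetical strategy with a real row filter.
    def filter_rows(self, row: dict) -> bool:
        return len(row.get("text", "")) > 0


class PlainStrategy:
    # Hypothetical strategy with no filter_rows attribute at all.
    pass
```

With the original truthiness check, `FlagStrategy` would pass the guard and `True` would be handed to `dataset.filter` as if it were a function; with the `callable()` check it is skipped, just like `PlainStrategy`.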
70-80: Consolidate map kwargs; ensure immutability and future extensibility.
Merging dicts inline is fine, but this block is duplicated with the streaming path below. Consider extracting a small helper that builds map kwargs consistently for both paths.
Follow-up: If you prefer an inline tweak now, you can leave as-is and address deduplication when touching wrap_dataset_for_tokenized_prompt.
92-99: Match batching semantics with non-streaming path (also add desc).
For IterableDataset you set batched=True but omit batch_size and desc. This can affect throughput and observability.
  if isinstance(dataset, IterableDataset):
-     map_kwargs = {}
-     if prompt_tokenizer.supports_batched:
-         map_kwargs["batched"] = True
+     map_kwargs: dict[str, Any] = {"desc": "Tokenizing Prompts"}
+     if prompt_tokenizer.supports_batched:
+         map_kwargs.update({"batched": True, "batch_size": 1_000})
92-109: Reduce duplication: share map option building between paths.
Both branches now assemble highly similar map kwargs. Extract a small helper to keep behavior in lockstep for IterableDataset and Dataset (batching, desc, etc.).
Example shape (outside this function):
def _build_map_kwargs(prompt_tokenizer, *, process_count=None, keep_in_memory=None, is_iterable=False):
    mk = {"desc": "Tokenizing Prompts"}
    if prompt_tokenizer.supports_batched:
        mk.update({"batched": True, "batch_size": 1_000})
    if not is_iterable:
        mk.update({
            "num_proc": process_count,
            "keep_in_memory": keep_in_memory,
        })
    return mk
Then use _build_map_kwargs(...) in both process(...) and wrap_dataset_for_tokenized_prompt(...).
9-17: Unused logger import; consider emitting actionable warnings.
LOG is currently unused. If you adopt the soft fallback for streaming remove_columns, emit a LOG.warning there; otherwise drop the import.
-from axolotl.utils.logging import get_logger
+from axolotl.utils.logging import get_logger
 ...
-LOG = get_logger(__name__)
+LOG = get_logger(__name__)  # keep only if warning/error is logged for IterableDataset
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/axolotl/datasets.py(4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/axolotl/datasets.py (4)
src/axolotl/prompt_strategies/chat_template.py (1)
tokenize_prompt(391-416)
src/axolotl/prompt_tokenizers.py (3)
tokenize_prompt(66-67), tokenize_prompt(118-144), tokenize_prompt(259-288)
src/axolotl/prompt_strategies/completion.py (1)
tokenize_prompt(42-60)
src/axolotl/prompt_strategies/pretrain.py (1)
tokenize_prompt(48-49)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (9)
- GitHub Check: pre-commit
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.0)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: pre-commit
- GitHub Check: PyTest (3.11, 2.7.0)
- GitHub Check: preview
🔇 Additional comments (1)
src/axolotl/datasets.py (1)
46-84: Add tests to safeguard streaming vs non-streaming column removal
I’ve confirmed that:
- load_dataset(..., streaming=True) is only used in src/axolotl/utils/data/sft.py, without an explicit features= argument, so all streaming datasets will hit the IterableDataset branch in process() (no remove_columns) [see scan results].
- There are no code paths forwarding entire batches via model(**batch), so unexpected columns would bubble through unless filtered out.
- The existing remove_columns logic in src/axolotl/datasets.py (lines 46–84) only applies to non-streaming (Dataset) inputs, while the streaming branch currently leaves original columns intact.
To lock in the intended behavior and prevent regressions, please add tests covering:
- IterableDataset with known schema
  - Construct an IterableDataset(features=...) with an extra dummy column.
  - After process(), assert that only the tokenizer’s output keys (e.g. input_ids, attention_mask) remain.
- IterableDataset without initial features
  - Use a raw streaming dataset (no .features).
  - Verify that either:
    - it raises a clear error when first encountering unexpected columns, or
    - it emits a warning and still only returns the tokenizer’s output keys.
- Non-streaming (Dataset) path
  - Map over a regular Dataset with multiple columns.
  - Confirm remove_columns drops the originals.
  - Confirm batching is enabled with batch_size=1000 when supports_batched=True.
These tests will ensure the code behaves correctly for both streaming and non-streaming datasets and guard against future changes.
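The column-leak behavior these tests target can be shown with a toy model. The sketch below is not the `datasets` library and not code from this PR: `toy_map` and `fake_tokenize` are hypothetical helpers that mimic, under that assumption, how a map step merges the function's output into each row after dropping `remove_columns`.

```python
from typing import Callable, Iterable, Iterator, Optional


def toy_map(
    rows: Iterable[dict],
    function: Callable[[dict], dict],
    remove_columns: Optional[list] = None,
) -> Iterator[dict]:
    """Toy stand-in for `.map(function, remove_columns=...)` on a stream."""
    drop = set(remove_columns or [])
    for row in rows:
        # Drop the requested source columns first, then merge the outputs.
        kept = {k: v for k, v in row.items() if k not in drop}
        kept.update(function(row))
        yield kept


def fake_tokenize(row: dict) -> dict:
    # Hypothetical tokenizer: one "token id" per whitespace-separated word.
    ids = list(range(len(row["text"].split())))
    return {"input_ids": ids, "attention_mask": [1] * len(ids)}


rows = [{"text": "hello streaming world", "source": "dummy"}]

# Without remove_columns, the original columns ride along with the tokens.
leaky = next(toy_map(rows, fake_tokenize))
# With remove_columns, only the tokenizer's output keys remain.
clean = next(toy_map(rows, fake_tokenize, remove_columns=["text", "source"]))
```

A real test would replace `toy_map` with the project's `process()` path and assert on the keys of the first yielded sample in the same way.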
return merged_dataset

def _merge_datasets_with_strategy(

yield sample

# Create new IterableDataset from the filtered generator
filtered_dataset = IterableDataset.from_generator(filtered_generator)
- filtered_dataset = IterableDataset.from_generator(filtered_generator)
+ filtered_dataset = IterableDataset.from_generator(filtered_generator, gen_kwargs={"batch_size": 1000})
batch_size as a kwarg feels better than a constant
winglian left a comment
Just a couple of nits. lgtm. Were you able to do any loss curve and step count comparisons to standard SFT w/o streaming?
Yes, everything looked okay except for sample packing + streaming, will need to fix.
Closing in favor of #3101.

Description
Streaming datasets should not be limited to pretraining only. This PR changes that, and also adds support for different sampling strategies (round robin and sampling according to weights). Also, preprocessing iterable datasets has been disallowed by this PR.
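For readers unfamiliar with the two sampling strategies named above, here is a plain-Python sketch of the idea. It is not the PR's implementation: `round_robin` and `weighted` are hypothetical helpers over ordinary iterables, and the choice to keep drawing from the remaining sources once one is exhausted is an assumption.

```python
import random
from typing import Iterable, Iterator, Sequence, TypeVar

T = TypeVar("T")


def round_robin(sources: Sequence[Iterable[T]]) -> Iterator[T]:
    """Yield one item from each source in turn, skipping exhausted sources."""
    iterators = [iter(s) for s in sources]
    while iterators:
        still_active = []
        for it in iterators:
            try:
                yield next(it)
            except StopIteration:
                continue
            still_active.append(it)
        iterators = still_active


def weighted(
    sources: Sequence[Iterable[T]], weights: Sequence[float], seed: int = 0
) -> Iterator[T]:
    """Draw each item from a source picked in proportion to its weight.

    Assumes every weight is positive; exhausted sources are dropped and the
    remaining weights are used as-is (random.choices normalizes them).
    """
    if len(sources) != len(weights):
        raise ValueError("need exactly one weight per source")
    rng = random.Random(seed)
    iterators = [iter(s) for s in sources]
    active = list(range(len(iterators)))
    while active:
        idx = rng.choices(active, weights=[weights[i] for i in active], k=1)[0]
        try:
            yield next(iterators[idx])
        except StopIteration:
            active.remove(idx)
```

For example, `list(round_robin([[1, 2, 3], ["a"]]))` alternates between the two sources until the shorter one runs out, then drains the longer one.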
Motivation and Context
Some folks have been asking for better dataset streaming support. This is part of a larger effort to refactor / improve the logic around how we handle datasets / dataloading.
How has this been tested?
Summary by CodeRabbit
New Features
Bug Fixes / Validation
Documentation
Refactor
Tests
Chores