feat:add support dataset_num_processes #3071
📝 Walkthrough
Renames the dataset parallelism configuration from dataset_processes to dataset_num_proc across config, environment, code paths, and docs; adds AXOLOTL_DATASET_NUM_PROC handling. Updates builders/trainers/data utilities to use the new key. Introduces ConstantLengthDataset in datasets.py and adds type annotations. Updates tests accordingly.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~30 minutes
Actionable comments posted: 3
🔭 Outside diff range comments (3)
src/axolotl/utils/data/rl.py (1)
114-120: Guard against num_proc == 0 for HF Datasets.map (fix required)
Short: The config validator only fills a default when dataset_num_proc is None (src/axolotl/utils/schemas/config.py), so a user-specified 0 will pass through and can cause HF Datasets to error (datasets.map(num_proc=0)). Coerce falsy 0 → None at call sites (or normalize centrally).
Change to apply (RL file example):
 dataset = dataset.map(
     ds_transform_fn,
-    num_proc=cfg.dataset_num_proc,
+    num_proc=(cfg.dataset_num_proc or None),
     load_from_cache_file=not cfg.is_preprocess,
     desc="Mapping RL Dataset",
     **map_kwargs,
 )
Places that need the same attention (please update to use (cfg.dataset_num_proc or None) or centralize normalization):
- src/axolotl/utils/data/rl.py — map at ~lines 114–120 and filter at ~lines 234–239
- src/axolotl/utils/trainer.py — filter_map_kwargs assignment (line ~279) and map calls at ~319, ~336, ~345
- src/axolotl/utils/data/utils.py — filter_map_kwargs assignment (~line 188)
- src/axolotl/utils/data/wrappers.py — dataset_kwargs["process_count"] (~line 84) — verify whether None is accepted by that wrapper; otherwise default to get_default_process_count()
- src/axolotl/core/trainers/base.py — num_processes=self.args.dataset_num_proc (~line 123) — verify downstream API accepts 0
Note: src/axolotl/utils/data/shared.py already guards with cfg.dataset_num_proc or get_default_process_count().
Recommendation: either (A) update all HF datasets.map/filter call sites to pass (cfg.dataset_num_proc or None), or (B) update the config validator (default_dataset_num_proc) to treat 0 as unset (convert 0 → None) so the value never reaches callers as 0.
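A minimal sketch of the "normalize centrally" option, assuming a small helper that call sites could share. The names `normalize_num_proc` and `build_map_kwargs` are hypothetical and do not exist in the PR; the sketch only relies on the behavior stated in this review, namely that num_proc=0 is rejected while num_proc=None runs single-process.

```python
from typing import Optional


def normalize_num_proc(value: Optional[int]) -> Optional[int]:
    """Map unset or non-positive worker counts to None (hypothetical helper).

    None is the safe fallback here because, per the review above,
    passing num_proc=0 to map/filter is an error.
    """
    if value is None or value < 1:
        return None
    return value


def build_map_kwargs(dataset_num_proc: Optional[int]) -> dict:
    """Build the num_proc kwarg for a Dataset.map / Dataset.filter call."""
    return {"num_proc": normalize_num_proc(dataset_num_proc)}
```

With a helper like this, each call site would pass `**build_map_kwargs(cfg.dataset_num_proc)` instead of repeating the `or None` guard inline.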
src/axolotl/utils/datasets.py (1)
6-13: Harden env var parsing in get_default_process_count (fix required)
Current implementation can raise ValueError for non-integer env values and may return None if os.cpu_count() is None. I verified call sites and occurrences; please apply the following change.
Files to update / review:
- src/axolotl/utils/datasets.py — update get_default_process_count() as shown below.
- Call sites that rely on this behavior:
- src/axolotl/utils/schemas/config.py:1282 — data["dataset_num_proc"] = get_default_process_count()
- src/axolotl/utils/data/shared.py:414 — num_workers = cfg.dataset_num_proc or get_default_process_count()
- Env var occurrences (CI / packaging):
- cicd/single_gpu.py:69-70 — sets AXOLOTL_DATASET_NUM_PROC / AXOLOTL_DATASET_PROCESSES
- cicd/Dockerfile.jinja:12-13 — ENV AXOLOTL_DATASET_NUM_PROC / AXOLOTL_DATASET_PROCESSES
- Deprecated key handling:
- src/axolotl/utils/schemas/config.py:225 and 1268–1274 — dataset_processes is deprecated and mapped to dataset_num_proc
Suggested patch:
 def get_default_process_count():
-    if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"):
-        return int(axolotl_dataset_num_proc)
-    if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
-        return int(axolotl_dataset_processes)
-    if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
-        return int(runpod_cpu_count)
-    return os.cpu_count()
+    def _parse_env_int(name):
+        val = os.environ.get(name)
+        if val is None:
+            return None
+        try:
+            parsed = int(val)
+            return parsed if parsed > 0 else None
+        except ValueError:
+            return None
+
+    for key in (
+        "AXOLOTL_DATASET_NUM_PROC",
+        "AXOLOTL_DATASET_PROCESSES",
+        "RUNPOD_CPU_COUNT",
+    ):
+        parsed = _parse_env_int(key)
+        if parsed is not None:
+            return parsed
+
+    # Ensure a sensible minimum of 1
+    return max(os.cpu_count() or 1, 1)
Also add a short docstring to get_default_process_count describing precedence:
AXOLOTL_DATASET_NUM_PROC > AXOLOTL_DATASET_PROCESSES > RUNPOD_CPU_COUNT > os.cpu_count().
src/axolotl/utils/data/shared.py (1)
414-424: Rename looks good; sanitize num_workers to avoid 0/None edge cases
The switch to cfg.dataset_num_proc is consistent. However, if dataset_num_proc is 0 (misconfigured) or get_default_process_count() returns None (rare, but possible), num_workers will become 0/None and cause failures:
- IterableDataset branch: range(num_workers) and num_proc=num_workers will error/behave unexpectedly.
- Non-iterable branch: min(..., num_workers) can compute to 0 and pass num_proc=0 to save_to_disk, which is invalid in HF datasets.
Clamp num_workers to at least 1 before use.
Apply this diff:
-    num_workers = cfg.dataset_num_proc or get_default_process_count()
+    num_workers = cfg.dataset_num_proc or get_default_process_count()
+    # Ensure valid worker count (HF datasets expects None or >=1; treat 0/None as 1 here)
+    try:
+        num_workers = int(num_workers) if num_workers is not None else None
+    except (TypeError, ValueError):
+        num_workers = None
+    if not num_workers or num_workers < 1:
+        num_workers = 1
🧹 Nitpick comments (7)
cicd/single_gpu.py (1)
69-71: Keep both envs in sync and centralize the default
Minor improvement: derive the value once (respecting any pre-set environment) and assign it to both keys. This avoids drift if one default changes.
Apply this diff:
-    sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
-    sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
+    num_proc = (
+        os.environ.get("AXOLOTL_DATASET_NUM_PROC")
+        or os.environ.get("AXOLOTL_DATASET_PROCESSES")
+        or "8"
+    )
+    sp_env["AXOLOTL_DATASET_NUM_PROC"] = num_proc
+    # Deprecated; kept for transition. Remove after the deprecation window.
+    sp_env["AXOLOTL_DATASET_PROCESSES"] = num_proc
src/axolotl/utils/data/rl.py (1)
235-241: Mirror the same None/0 guard for filter’s num_proc
Same rationale as above for map: filter(num_proc=0) is risky.
Apply this diff:
 split_datasets[i] = split_datasets[i].filter(
     drop_long,
-    num_proc=cfg.dataset_num_proc,
+    num_proc=(cfg.dataset_num_proc or None),
     load_from_cache_file=not cfg.is_preprocess,
     desc="Dropping Long Sequences",
 )
src/axolotl/utils/data/utils.py (1)
186-191: Coerce falsy dataset_num_proc to None to avoid HF errors
Align behavior with a safe default: ensure num_proc is None when unset or 0.
Apply this diff:
 filter_map_kwargs = {}
 if not isinstance(dataset, IterableDataset):
-    filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
+    filter_map_kwargs["num_proc"] = (cfg.dataset_num_proc or None)
     filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
docs/debugging.qmd (1)
32-32: Add mention of the new env var for completeness
Since AXOLOTL_DATASET_NUM_PROC is now supported, it’s useful to surface it here alongside config/CLI.
Apply this diff:
- - Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`.
+ - Set `dataset_num_proc: 1` in your axolotl config, export `AXOLOTL_DATASET_NUM_PROC=1`, or run the training command with `--dataset_num_proc=1`.
tests/test_datasets.py (1)
144-145: Rename to dataset_num_proc across dataset tests: LGTM; consider adding a deprecated-key coverage test
Ran the grep you provided over tests.
Findings:
- "dataset_num_proc" is used across tests (examples: tests/test_packed_dataset.py, tests/test_exact_deduplication.py, tests/test_datasets.py, tests/e2e/test_llama_pretrain.py, tests/core/test_builders.py, tests/e2e/patched/test_activation_checkpointing.py).
- "dataset_processes" (deprecated) has 0 occurrences in tests.
Recommendation (optional but useful): add a small unit test (config-normalization scope) that:
- sets only dataset_processes and asserts it maps to cfg.dataset_num_proc;
- sets both keys with different values and asserts dataset_num_proc wins and a deprecation warning is emitted.
I can draft that minimal test if you want.
src/axolotl/utils/trainer.py (1)
279-279: Optional: Avoid setting num_proc key when it’s None
HF datasets treats num_proc=None as single-process. You can simplify by only setting the key if cfg.dataset_num_proc is truthy to reduce kwargs noise.
-    filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
+    if cfg.dataset_num_proc:
+        filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
src/axolotl/utils/schemas/config.py (1)
226-235: Deprecation field is clear; consider marking JSON schema as deprecated for tooling
The deprecation messaging is good. For better JSON schema/OpenAPI tooling support, set a boolean deprecated flag in json_schema_extra as well.
 dataset_processes: int | None = Field(
     default=None,
     deprecated="Use `dataset_num_proc` instead. This parameter will be removed in a future version.",
     json_schema_extra={
         "description": (
             "DEPRECATED: Use `dataset_num_proc` instead. "
             "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
             "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
         ),
+        "deprecated": True,
     },
 )
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (18)
- cicd/Dockerfile.jinja (1 hunks)
- cicd/single_gpu.py (1 hunks)
- devtools/dev_chat_template.yml (1 hunks)
- docs/debugging.qmd (2 hunks)
- src/axolotl/core/builders/base.py (1 hunks)
- src/axolotl/utils/data/rl.py (2 hunks)
- src/axolotl/utils/data/shared.py (1 hunks)
- src/axolotl/utils/data/utils.py (1 hunks)
- src/axolotl/utils/data/wrappers.py (1 hunks)
- src/axolotl/utils/datasets.py (1 hunks)
- src/axolotl/utils/schemas/config.py (2 hunks)
- src/axolotl/utils/trainer.py (5 hunks)
- tests/core/test_builders.py (2 hunks)
- tests/e2e/patched/test_activation_checkpointing.py (1 hunks)
- tests/e2e/test_llama_pretrain.py (1 hunks)
- tests/test_datasets.py (7 hunks)
- tests/test_exact_deduplication.py (1 hunks)
- tests/test_packed_dataset.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (8)
src/axolotl/utils/data/utils.py (2)
tests/test_exact_deduplication.py (1)
cfg(201-216)
src/axolotl/integrations/base.py (2)
cfg(352-353)
cfg(356-357)

src/axolotl/core/builders/base.py (2)
tests/test_exact_deduplication.py (1)
cfg(201-216)
src/axolotl/integrations/base.py (2)
cfg(352-353)
cfg(356-357)

tests/core/test_builders.py (1)
tests/test_exact_deduplication.py (1)
cfg(201-216)

src/axolotl/utils/data/wrappers.py (1)
tests/test_exact_deduplication.py (1)
cfg(201-216)

src/axolotl/utils/data/shared.py (2)
tests/test_exact_deduplication.py (1)
cfg(201-216)
src/axolotl/utils/datasets.py (1)
get_default_process_count(6-13)

src/axolotl/utils/trainer.py (2)
tests/test_exact_deduplication.py (1)
cfg(201-216)
src/axolotl/integrations/base.py (2)
cfg(352-353)
cfg(356-357)

src/axolotl/utils/data/rl.py (2)
tests/test_exact_deduplication.py (1)
cfg(201-216)
src/axolotl/integrations/base.py (2)
cfg(352-353)
cfg(356-357)

src/axolotl/utils/schemas/config.py (1)
src/axolotl/utils/datasets.py (1)
get_default_process_count(6-13)
🔇 Additional comments (12)
devtools/dev_chat_template.yml (1)
16-16: Renamed key aligns with the new canonical config
Switching to dataset_num_proc here matches the updated schema and docs. Looks good.
tests/e2e/patched/test_activation_checkpointing.py (1)
73-74: Test config updated to the new key
Using dataset_num_proc in this test fixture is consistent with the rename and should exercise the new code path.
src/axolotl/utils/data/wrappers.py (1)
83-86: Approved: cfg.dataset_num_proc is correct; only deprecated alias remains
Search shows only deprecation/alias handling for `dataset_processes`; the wrapper change is safe.
- src/axolotl/utils/schemas/config.py:225: `dataset_processes` Field with deprecated="Use `dataset_num_proc` instead."
- src/axolotl/utils/schemas/config.py:1268–1280: `default_dataset_num_proc` copies `dataset_processes` → `dataset_num_proc` and logs deprecation warnings.
- src/axolotl/utils/data/wrappers.py:83–86: uses `cfg.dataset_num_proc` for `process_count` (change under review).
tests/e2e/test_llama_pretrain.py (1)
33-34: LGTM: renamed key to dataset_num_proc
Consistent with the PR-wide rename; keeps behavior unchanged (value=1).
tests/test_exact_deduplication.py (1)
213-214: LGTM: test fixture key rename
Matches the new canonical config key and maintains the same semantics (4 workers).
tests/test_packed_dataset.py (1)
102-102: Renamed key to dataset_num_proc: LGTM
Config rename in the test input matches the repo-wide migration and keeps the same semantics.
tests/core/test_builders.py (1)
85-85: Consistent switch to dataset_num_proc in builders tests: LGTM
Both the base fixture and the per-test override now use the canonical key; aligns with updated builders reading cfg.dataset_num_proc.
Also applies to: 446-446
docs/debugging.qmd (1)
104-104: CLI args updated to dataset_num_proc: LGTM
Docs reflect the new canonical flag in the VSCode launch example.
src/axolotl/utils/data/shared.py (1)
433-440: Prevent passing num_proc=0 to save_to_disk
With the clamped num_workers above, this becomes safe. Without it, the expression can yield 0 and break. No change needed here if you adopt the clamp; otherwise, consider guarding num_workers to be >=1.
src/axolotl/utils/trainer.py (1)
279-279: Consistent rename to dataset_num_proc; ensure value is >= 1 everywhere
All updates to use cfg.dataset_num_proc look correct and consistent. One behavioral edge case: if cfg.dataset_num_proc is 0 (or negative), HF datasets and MultipackBatchSampler will misbehave or error. In shared.py you already fall back via “or get_default_process_count()”, but here you pass the value through directly.
Prefer centralizing a clamp at config load (recommended below), or locally coalescing to None/1.
If you don’t adopt the config-level clamp, a minimal local guard would look like:
-    filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
+    filter_map_kwargs["num_proc"] = (
+        None if cfg.dataset_num_proc in (None, 0) else max(1, cfg.dataset_num_proc)
+    )
Same applies to:
- Line 319: num_proc in map()
- Line 336: num_proc in PoSE map()
- Line 345: num_proc in eval PoSE map()
- Line 470: num_processes in MultipackBatchSampler
Also applies to: 319-319, 336-336, 345-345, 470-472
src/axolotl/utils/schemas/config.py (2)
17-24: Confirm that get_default_process_count() behavior is as intended
This validator relies on get_default_process_count() which now prioritizes AXOLOTL_DATASET_NUM_PROC over AXOLOTL_DATASET_PROCESSES and then RUNPOD_CPU_COUNT. That’s consistent with the PR goals. Just ensure documentation mentions the new env var and that “0” is not a valid value (after the clamp).
Would you like me to update docs to explicitly state:
- dataset_num_proc must be >= 1
- AXOLOTL_DATASET_NUM_PROC is preferred over AXOLOTL_DATASET_PROCESSES
- If neither is set, defaults to os.cpu_count() (or 1 if undetermined)?
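The precedence listed above could be documented alongside a small, testable resolver. The sketch below is an illustration only: `resolve_default_num_proc`, `ENV_PRECEDENCE`, and the injectable `env` mapping are hypothetical names, not part of the PR, and invalid or non-positive values are assumed to be skipped rather than raised.

```python
import os
from typing import Mapping, Optional

# Highest priority first; RUNPOD_CPU_COUNT is the platform-provided fallback.
ENV_PRECEDENCE = (
    "AXOLOTL_DATASET_NUM_PROC",
    "AXOLOTL_DATASET_PROCESSES",
    "RUNPOD_CPU_COUNT",
)


def resolve_default_num_proc(env: Optional[Mapping[str, str]] = None) -> int:
    """Return the first valid (>= 1) value in precedence order, else the CPU count."""
    env = os.environ if env is None else env
    for name in ENV_PRECEDENCE:
        raw = env.get(name)
        if raw is None:
            continue
        try:
            parsed = int(raw)
        except ValueError:
            continue  # non-integer values are ignored in this sketch
        if parsed >= 1:
            return parsed
    # os.cpu_count() may return None; fall back to a single process.
    return max(os.cpu_count() or 1, 1)
```

Passing the environment as a mapping keeps the precedence rule easy to unit test without mutating `os.environ`.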
1-1: Repo-wide check OK: only compat-layer usages of `dataset_processes` remain
Search found references only in the compatibility code inside src/axolotl/utils/schemas/config.py; no other occurrences in the repo.
- src/axolotl/utils/schemas/config.py, line ~225: schema field `dataset_processes: int | None = Field(..., deprecated="Use dataset_num_proc instead...")`
- src/axolotl/utils/schemas/config.py, lines ~1268–1279: `default_dataset_num_proc` classmethod that copies `dataset_processes` -> `dataset_num_proc` and logs deprecation warnings
No further action required.
ENV AXOLOTL_DATASET_NUM_PROC="8"
ENV AXOLOTL_DATASET_PROCESSES="8"
💡 Verification agent
🧩 Analysis chain
Add deprecation note or plan for the legacy env var
Setting both AXOLOTL_DATASET_NUM_PROC and AXOLOTL_DATASET_PROCESSES is fine for a transition period. Consider annotating the legacy variable to avoid future confusion and to make the removal plan explicit.
Apply this diff to add a brief deprecation note:
ENV AXOLOTL_DATASET_NUM_PROC="8"
-ENV AXOLOTL_DATASET_PROCESSES="8"
+# Deprecated: kept for backward-compat. Prefer AXOLOTL_DATASET_NUM_PROC; remove after deprecation window.
+ENV AXOLOTL_DATASET_PROCESSES="8"
Run this to verify the precedence in code (NUM_PROC should be preferred over PROCESSES):
🏁 Script executed:
#!/bin/bash
# Expect: A code path that reads AXOLOTL_DATASET_NUM_PROC first, then falls back to AXOLOTL_DATASET_PROCESSES.
rg -nP -C3 'AXOLOTL_DATASET_(NUM_PROC|PROCESSES)' src | sed -n '1,200p'
# Also check for any lingering config key usages of `dataset_processes` outside of deprecation/alias handling.
rg -nP -C2 '\bdataset_processes\b' | sed -n '1,200p'
Length of output: 2466
Annotate legacy AXOLOTL_DATASET_PROCESSES in cicd/Dockerfile.jinja
Verified: AXOLOTL_DATASET_NUM_PROC is preferred at runtime and dataset_processes is already deprecated/migrated in the config — please add a deprecation note in the Dockerfile to avoid confusion.
Files to note:
- cicd/Dockerfile.jinja — lines 12–13: add deprecation comment for the legacy env var.
- Verified precedence and deprecation handling in:
- src/axolotl/utils/datasets.py — get_default_process_count() checks AXOLOTL_DATASET_NUM_PROC then falls back to AXOLOTL_DATASET_PROCESSES (lines ~7–11).
- src/axolotl/utils/schemas/config.py — dataset_processes is declared deprecated (line ~225) and migrated into dataset_num_proc with warnings (lines ~1268–1279).
Apply this diff:
ENV AXOLOTL_DATASET_NUM_PROC="8"
-ENV AXOLOTL_DATASET_PROCESSES="8"
+# Deprecated: kept for backward-compat. Prefer AXOLOTL_DATASET_NUM_PROC; remove after deprecation window.
+ENV AXOLOTL_DATASET_PROCESSES="8"
🤖 Prompt for AI Agents
In cicd/Dockerfile.jinja around lines 12–13, the legacy
AXOLOTL_DATASET_PROCESSES env var is set without any deprecation note; add a
concise comment next to that ENV line stating that AXOLOTL_DATASET_NUM_PROC is
the preferred runtime variable, AXOLOTL_DATASET_PROCESSES is deprecated and kept
only for backward compatibility, and that code now prefers
AXOLOTL_DATASET_NUM_PROC (see src/axolotl/utils/datasets.py and
src/axolotl/utils/schemas/config.py for migration/deprecation handling). Ensure
the comment is short, mentions backward compatibility, and does not remove the
existing ENV assignment.
Co-authored-by: Dan Saunders <danjsaund@gmail.com>
if self.cfg.dataset_num_proc:
    training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_num_proc
Could add it to the above list instead of separate if condition here
else:
    LOG.warning(
        "Both dataset_processes and dataset_num_proc are set. "
        "Using dataset_num_proc and ignoring dataset_processes."
    )
Should probably raise an error as it's a conflict?
Actionable comments posted: 1
🧹 Nitpick comments (3)
src/axolotl/utils/schemas/config.py (3)
226-235: Use boolean JSON Schema deprecation flag; keep explanation in description
`deprecated=` should be a boolean in JSON Schema. Move it under `json_schema_extra` as `deprecated: True` and avoid duplicating the message.
-    dataset_processes: int | None = Field(
-        default=None,
-        deprecated="Use `dataset_num_proc` instead. This parameter will be removed in a future version.",
-        json_schema_extra={
-            "description": (
-                "DEPRECATED: Use `dataset_num_proc` instead. "
-                "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
-                "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
-            )
-        },
-    )
+    dataset_processes: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "deprecated": True,
+            "description": (
+                "DEPRECATED: Use `dataset_num_proc` instead. "
+                "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
+                "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
+            ),
+        },
+    )
236-244: Enforce >=1 at schema level and document env override order
Add `ge=1` and note the override precedence including the new env var.
-    dataset_num_proc: int | None = Field(
-        default=None,
-        json_schema_extra={
-            "description": (
-                "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
-                "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
-            )
-        },
-    )
+    dataset_num_proc: int | None = Field(
+        default=None,
+        ge=1,
+        json_schema_extra={
+            "description": (
+                "The maximum number of processes to use while preprocessing your input dataset. "
+                "Override/default precedence: AXOLOTL_DATASET_NUM_PROC > AXOLOTL_DATASET_PROCESSES > RUNPOD_CPU_COUNT > os.cpu_count()."
+            )
+        },
+    )
236-244: Support `dataset_num_processes` alias or update docs: the PR title/docs refer to `dataset_num_processes` but the schema only accepts `dataset_num_proc`; either add `validation_alias=AliasChoices("dataset_num_proc","dataset_num_processes")` on that Field or align the title/docs accordingly.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/axolotl/core/builders/base.py (1 hunks)
- src/axolotl/utils/schemas/config.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/axolotl/core/builders/base.py
🧰 Additional context used
🧬 Code graph analysis (1)
src/axolotl/utils/schemas/config.py (1)
src/axolotl/utils/datasets.py (1)
get_default_process_count(6-13)
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/config.py
1276-1279: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
- GitHub Check: preview
def default_dataset_num_proc(cls, data):
    if data.get("dataset_processes") is not None:
        if data.get("dataset_num_proc") is None:
            data["dataset_num_proc"] = data["dataset_processes"]
            LOG.warning(
                "dataset_processes is deprecated and will be removed in a future version. "
                "Please use dataset_num_proc instead."
            )
        else:
            raise ValueError(
                "Both dataset_processes and dataset_num_proc are set. "
                "Using dataset_num_proc and ignoring dataset_processes."
            )

    if data.get("dataset_num_proc") is None:
        data["dataset_num_proc"] = get_default_process_count()

    return data
🛠️ Refactor suggestion
Harden defaulting/parsing, clamp to >=1, and fix conflict error message
- Current flow can yield 0/None or non-int, leading to runtime errors downstream.
- Conflict error message says it will ignore one value but raises instead—make the message consistent.
     @model_validator(mode="before")
     @classmethod
     def default_dataset_num_proc(cls, data):
-        if data.get("dataset_processes") is not None:
-            if data.get("dataset_num_proc") is None:
-                data["dataset_num_proc"] = data["dataset_processes"]
-                LOG.warning(
-                    "dataset_processes is deprecated and will be removed in a future version. "
-                    "Please use dataset_num_proc instead."
-                )
-            else:
-                raise ValueError(
-                    "Both dataset_processes and dataset_num_proc are set. "
-                    "Using dataset_num_proc and ignoring dataset_processes."
-                )
-
-        if data.get("dataset_num_proc") is None:
-            data["dataset_num_proc"] = get_default_process_count()
+        # Handle deprecated -> canonical
+        if data.get("dataset_processes") is not None and data.get("dataset_num_proc") is None:
+            data["dataset_num_proc"] = data["dataset_processes"]
+            LOG.warning(
+                "dataset_processes is deprecated and will be removed in a future version. "
+                "Please use dataset_num_proc instead."
+            )
+        elif data.get("dataset_processes") is not None and data.get("dataset_num_proc") is not None:
+            # Conflict: be explicit and fail fast
+            raise ValueError(
+                "Conflicting config: dataset_processes and dataset_num_proc cannot both be set. "
+                "Remove dataset_processes."
+            )
+
+        # Fill default if unset, coerce and clamp to >=1
+        if data.get("dataset_num_proc") is None:
+            try:
+                default = get_default_process_count()
+            except (TypeError, ValueError):
+                default = None
+            data["dataset_num_proc"] = default if (isinstance(default, int) and default > 0) else 1
+        else:
+            try:
+                val = int(data["dataset_num_proc"])
+            except (TypeError, ValueError):
+                raise ValueError("dataset_num_proc must be an integer >= 1") from None
+            if val < 1:
+                LOG.warning("dataset_num_proc must be >= 1; clamping to 1.")
+                val = 1
+            data["dataset_num_proc"] = val
🤖 Prompt for AI Agents
In src/axolotl/utils/schemas/config.py around lines 1267-1284, harden parsing in
default_dataset_num_proc: when reading dataset_processes/dataset_num_proc coerce
values to int (raise ValueError if not convertible), prefer dataset_num_proc
when both are provided but do not raise — replace the current raise with a
LOG.warning that clearly states dataset_num_proc will be used and
dataset_processes ignored, ensure the final dataset_num_proc is clamped to at
least 1 (if <=0 set to 1 and LOG.warning about the clamp), and return the
sanitized data; this prevents None/0/non-int downstream errors and makes the
conflict message consistent with behavior.
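The behavior this prompt describes (coerce to int, prefer dataset_num_proc with a warning when both keys are set, clamp to at least 1) can be sketched as a plain function over the raw config dict. This is a hedged illustration, not the PR's validator: `sanitize_dataset_num_proc` and its `default_factory` parameter are hypothetical, and the sketch follows the prompt's "warn" wording rather than the "raise" variant shown in the diff above.

```python
import logging
from typing import Any, Callable, Dict, Optional

LOG = logging.getLogger(__name__)


def sanitize_dataset_num_proc(
    data: Dict[str, Any],
    default_factory: Callable[[], Optional[int]] = lambda: 1,
) -> Dict[str, Any]:
    """Normalize dataset_processes / dataset_num_proc in a raw config dict."""

    def _to_int(key: str) -> Optional[int]:
        value = data.get(key)
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError):
            raise ValueError(f"{key} must be an integer >= 1") from None

    legacy = _to_int("dataset_processes")
    current = _to_int("dataset_num_proc")

    if legacy is not None and current is None:
        LOG.warning("dataset_processes is deprecated; use dataset_num_proc instead.")
        current = legacy
    elif legacy is not None:
        LOG.warning(
            "Both dataset_processes and dataset_num_proc are set; "
            "using dataset_num_proc and ignoring dataset_processes."
        )

    if current is None:
        # default_factory stands in for get_default_process_count() (assumption).
        current = default_factory() or 1
    if current < 1:
        LOG.warning("dataset_num_proc must be >= 1; clamping to 1.")
        current = 1

    data["dataset_num_proc"] = current
    return data
```

Whether the conflict case should warn or raise is still an open question in this thread; swapping the second `LOG.warning` for a `ValueError` gives the fail-fast variant.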
Actionable comments posted: 3
🧹 Nitpick comments (5)
src/axolotl/datasets.py (5)
48-74: Use column_names instead of features.keys() for remove_columns.
More robust across Dataset variants.
Apply:
-        features = dataset.features.keys()
+        columns = dataset.column_names
@@
-            remove_columns=features,
+            remove_columns=columns,
76-92: Guard IterableDataset without features and pass remove_columns conditionally.
Some streaming datasets don’t populate `features`.
Apply:
 if isinstance(dataset, IterableDataset):
     map_kwargs = {}
     if prompt_tokenizer.supports_batched:
         map_kwargs["batched"] = True
-    features = list(dataset.features.keys())
-    return dataset.map(
-        prompt_tokenizer.tokenize_prompt,
-        remove_columns=features,
-        **map_kwargs,
-    )
+    feats = getattr(dataset, "features", None)
+    if feats is not None:
+        map_kwargs["remove_columns"] = list(feats.keys())
+    return dataset.map(prompt_tokenizer.tokenize_prompt, **map_kwargs)
105-115: Set safe concat/pad tokens; avoid None EOS.
Fallback to PAD when EOS is missing; also precompute `pad_token_id`.
Apply:
 self.tokenizer = tokenizer
-self.concat_token_id = tokenizer.eos_token_id
+self.concat_token_id = (
+    tokenizer.eos_token_id
+    if tokenizer.eos_token_id is not None
+    else tokenizer.pad_token_id
+)
+self.pad_token_id = (
+    tokenizer.pad_token_id
+    if tokenizer.pad_token_id is not None
+    else self.concat_token_id
+)
 self.datasets: list[IterableDataset] = datasets
 self.seq_length = seq_length
191-199: Avoid dropping long examples; chunk instead.
Instead of skipping, split long records into `seq_length`-sized shards (with optional overlap) to preserve data.
I can provide a small helper to slice `example["input_ids"]`/`labels` into multiple sub-examples if you want it in this PR.
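A possible shape for such a helper, as a sketch only. `chunk_example` is a hypothetical name, and the sketch assumes every column in the example (for instance input_ids, labels, attention_mask) is a per-token list of the same length, so that all columns can be sliced with the same window.

```python
from typing import Dict, Iterator, List


def chunk_example(
    example: Dict[str, List[int]],
    seq_length: int,
    overlap: int = 0,
) -> Iterator[Dict[str, List[int]]]:
    """Yield seq_length-sized windows of a tokenized example.

    Consecutive windows share `overlap` tokens; the final window may be shorter.
    """
    if seq_length < 1:
        raise ValueError("seq_length must be >= 1")
    if not 0 <= overlap < seq_length:
        raise ValueError("overlap must satisfy 0 <= overlap < seq_length")

    total = len(example["input_ids"])
    stride = seq_length - overlap
    start = 0
    while start < total:
        end = start + seq_length
        # Slice every column with the same window so they stay aligned.
        yield {key: values[start:end] for key, values in example.items()}
        if end >= total:
            break
        start += stride
```

An example no longer than `seq_length` passes through as a single chunk, so the helper could replace the skip branch without changing behavior for short records.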
94-104: Interleaving note.
If multiple datasets are provided, consider round-robin interleaving to reduce domain burstiness.
I can sketch an iterator that cycles over `self.datasets` and yields one example at a time with a cap per-source.
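One way such an iterator might look, as a sketch under assumptions: `round_robin` is a hypothetical name, and "cap per-source" is interpreted here as a maximum total number of examples drawn from each source, which is only one possible reading of the comment.

```python
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def round_robin(
    sources: List[Iterable[T]],
    cap_per_source: Optional[int] = None,
) -> Iterator[T]:
    """Yield one item per source in turn until every source is exhausted or capped."""
    iterators = [iter(source) for source in sources]
    counts = [0] * len(iterators)
    active = list(range(len(iterators)))
    while active:
        still_active = []
        for idx in active:
            if cap_per_source is not None and counts[idx] >= cap_per_source:
                continue  # this source has reached its cap
            try:
                item = next(iterators[idx])
            except StopIteration:
                continue  # this source is exhausted
            counts[idx] += 1
            yield item
            still_active.append(idx)
        active = still_active
```

Because it only calls `iter` and `next`, the sketch works with streaming sources as well as in-memory lists.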
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
src/axolotl/datasets.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/axolotl/datasets.py (2)
src/axolotl/utils/logging.py (1)
get_logger (42-49)
src/axolotl/prompt_tokenizers.py (1)
PromptTokenizingStrategy (43-105)
🔇 Additional comments (1)
src/axolotl/datasets.py (1)
32-47: Wiring of dataset_num_proc to process_count is correct. The dataset_kwargs in src/axolotl/utils/data/wrappers.py maps cfg.dataset_num_proc into the process_count parameter of TokenizedPromptDataset, and all explicit TokenizedPromptDataset instantiations in tests likewise pass process_count. No further changes needed.
vocab_size = len(tokenizer.get_vocab())

if vocab_size <= torch.iinfo(torch.int16).max:
    self.tokens_dtype = torch.int16
elif vocab_size <= torch.iinfo(torch.int32).max:
    self.tokens_dtype = torch.int32
else:
    self.tokens_dtype = torch.int64
Token dtype must be torch.long for model embeddings.
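Using int16 for token indices will break nn.Embedding, which accepts only int32 or int64 index tensors, and many HF models and loss functions expect int64 (torch.long) throughout.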
Apply:
- vocab_size = len(tokenizer.get_vocab())
-
- if vocab_size <= torch.iinfo(torch.int16).max:
-     self.tokens_dtype = torch.int16
- elif vocab_size <= torch.iinfo(torch.int32).max:
-     self.tokens_dtype = torch.int32
- else:
-     self.tokens_dtype = torch.int64
+ # Embedding layers expect Long tensor indices; keep masks integral too.
+ self.tokens_dtype = torch.long
+ self.mask_dtype = torch.long
🤖 Prompt for AI Agents
In src/axolotl/datasets.py around lines 116-124, the code currently chooses
int16/int32/int64 based on vocab_size, but token indices must be torch.long
(int64) for nn.Embedding and HF models. Replace the branching logic so
self.tokens_dtype is always set to torch.long (torch.int64); update any
downstream casts/usages to use this dtype to ensure compatibility with
embeddings.
    buffer_len + int(add_concat_token) + example_len > self.seq_length
):
    if buffer["input_ids"]:
        input_ids = torch.cat(buffer["input_ids"], dim=-1)[
            : self.seq_length
        ]
        attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
            : self.seq_length
        ]
        position_ids = torch.cat(buffer["position_ids"], dim=-1)[
            : self.seq_length
        ]
        labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
        if labels.size() == input_ids.size() and (
            attention_mask.size() == input_ids.size()
        ):
            yield {
                "input_ids": input_ids,
                "labels": labels,
                "attention_mask": attention_mask,
                "position_ids": position_ids,
            }
        else:
            LOG.warning(
                "Dropping batch due to tensor size mismatch "
                f"input_ids: {input_ids.size()}, "
                f"labels: {labels.size()}, "
                f"attention_mask: {attention_mask.size()}"
            )
    buffer = {
        "input_ids": [],
🛠️ Refactor suggestion
Guarantee constant-length output by padding/truncation at flush.
Currently the final batch can be shorter than seq_length, violating the class contract and risking collate failures.
Apply:
- if buffer["input_ids"]:
- input_ids = torch.cat(buffer["input_ids"], dim=-1)[
- : self.seq_length
- ]
- attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
- : self.seq_length
- ]
- position_ids = torch.cat(buffer["position_ids"], dim=-1)[
- : self.seq_length
- ]
- labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
- if labels.size() == input_ids.size() and (
- attention_mask.size() == input_ids.size()
- ):
- yield {
- "input_ids": input_ids,
- "labels": labels,
- "attention_mask": attention_mask,
- "position_ids": position_ids,
- }
- else:
- LOG.warning(
- "Dropping batch due to tensor size mismatch "
- f"input_ids: {input_ids.size()}, "
- f"labels: {labels.size()}, "
- f"attention_mask: {attention_mask.size()}"
- )
+ if buffer["input_ids"]:
+ input_ids = torch.cat(buffer["input_ids"], dim=-1)
+ attention_mask = torch.cat(buffer["attention_mask"], dim=-1)
+ labels = torch.cat(buffer["labels"], dim=-1)
+ cur_len = input_ids.size(-1)
+ # truncate if overflowed
+ if cur_len > self.seq_length:
+ input_ids = input_ids[: self.seq_length]
+ attention_mask = attention_mask[: self.seq_length]
+ labels = labels[: self.seq_length]
+ cur_len = self.seq_length
+ # pad if underfull
+ if cur_len < self.seq_length:
+ pad_len = self.seq_length - cur_len
+ input_ids = torch.cat(
+ [input_ids,
+ torch.full((pad_len,), self.pad_token_id, dtype=self.tokens_dtype)],
+ dim=-1,
+ )
+ attention_mask = torch.cat(
+ [attention_mask,
+ torch.zeros((pad_len,), dtype=self.mask_dtype)],
+ dim=-1,
+ )
+ labels = torch.cat(
+ [labels,
+ torch.full((pad_len,), -100, dtype=self.tokens_dtype)],
+ dim=-1,
+ )
+ # recompute contiguous position ids for the finalized chunk
+ position_ids = torch.arange(self.seq_length, dtype=self.tokens_dtype)
+ yield {
+ "input_ids": input_ids,
+ "labels": labels,
+ "attention_mask": attention_mask,
+ "position_ids": position_ids,
+ }📝 Committable suggestion
🤖 Prompt for AI Agents
In src/axolotl/datasets.py around lines 152-182, the flush path can yield
tensors shorter than self.seq_length; modify the flush so every yielded tensor
is exactly self.seq_length by padding/truncating after concatenation: after
building input_ids, attention_mask, position_ids and labels and slicing to at
most self.seq_length, if any tensor is shorter than self.seq_length pad
input_ids with self.pad_token_id (or 0 if not available), pad labels with
self.ignore_index (commonly -100) so they do not contribute to loss, pad
attention_mask with 0, and pad position_ids with 0 (or appropriate incremental
positions if required) until length == self.seq_length; keep the existing
size-check and warning, then yield the fixed-length batch and reset buffer.
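The pad-or-truncate rule described in this comment can be illustrated with plain lists. This is a sketch under assumptions, not the dataset class itself: `finalize_chunk` and `IGNORE_INDEX` are hypothetical names, and position ids are recomputed as one contiguous range, which mirrors the suggested diff rather than any per-example numbering.

```python
IGNORE_INDEX = -100  # label value conventionally skipped by the loss


def finalize_chunk(
    input_ids: list[int],
    attention_mask: list[int],
    labels: list[int],
    seq_length: int,
    pad_token_id: int,
) -> dict[str, list[int]]:
    """Return a chunk whose fields all have exactly seq_length entries."""
    # Truncate first; slicing is a no-op when the chunk is already short enough.
    input_ids = input_ids[:seq_length]
    attention_mask = attention_mask[:seq_length]
    labels = labels[:seq_length]

    pad_len = seq_length - len(input_ids)
    if pad_len > 0:
        input_ids = input_ids + [pad_token_id] * pad_len
        attention_mask = attention_mask + [0] * pad_len  # padding is not attended to
        labels = labels + [IGNORE_INDEX] * pad_len  # padding does not affect the loss

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "position_ids": list(range(seq_length)),
    }


print(finalize_chunk([5, 6, 7], [1, 1, 1], [5, 6, 7], seq_length=5, pad_token_id=0))
```

With every field forced to the same length, the size-mismatch branch can only trigger if the inputs themselves disagree in length before this step.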
Fix attention_mask scaling and position_ids; enforce correct dtypes.
attention_mask must be 0/1, not multiplied by idx. position_ids should advance with buffer_len. Also use torch.long.
Apply:
- input_ids_with_concat = torch.tensor(
-     input_ids, dtype=self.tokens_dtype
- )
- attention_mask_with_concat = torch.tensor(
-     [idx * m for m in attention_mask], dtype=torch.int16
- )
- labels_with_concat = torch.tensor(
-     labels, dtype=self.tokens_dtype
- )
- position_ids = torch.arange(
-     len(input_ids), dtype=self.tokens_dtype
- )
+ input_ids_with_concat = torch.tensor(
+     input_ids, dtype=self.tokens_dtype
+ )
+ attention_mask_with_concat = torch.tensor(
+     attention_mask, dtype=self.mask_dtype
+ )
+ labels_with_concat = torch.tensor(
+     labels, dtype=self.tokens_dtype
+ )
+ position_ids = torch.arange(
+     buffer_len, buffer_len + len(input_ids), dtype=self.tokens_dtype
+ )
🤖 Prompt for AI Agents
In src/axolotl/datasets.py around lines 204–221, fix the attention_mask and
position_ids creation and enforce torch.long dtype: do not scale attention_mask
by idx (keep it as 0/1), convert it to a long tensor (e.g., (attention_mask >
0).long() or torch.tensor(attention_mask, dtype=torch.long)), create
position_ids using the current buffer_len offset (torch.arange(buffer_len,
buffer_len + len(input_ids), dtype=torch.long)) so positions advance correctly,
and construct input_ids and labels tensors with dtype=torch.long as well; then
append these long-typed tensors to the buffer and update buffer_len by
len(input_ids).
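To make the buffer_len offset concrete, here is a list-based sketch of the append step as this comment describes it. `append_example` is a hypothetical helper, and the sketch only mirrors the reviewer's proposal; whether other packing code relies on per-example mask values or per-example position resets is not stated in this thread.

```python
def append_example(
    buffer: dict[str, list[int]],
    buffer_len: int,
    input_ids: list[int],
    attention_mask: list[int],
    labels: list[int],
) -> int:
    """Append one example to the packing buffer and return the new buffer length."""
    buffer["input_ids"].extend(input_ids)
    # Keep the mask binary instead of scaling it by the example index.
    buffer["attention_mask"].extend(1 if m else 0 for m in attention_mask)
    buffer["labels"].extend(labels)
    # Positions continue from the current buffer length instead of restarting at 0.
    buffer["position_ids"].extend(range(buffer_len, buffer_len + len(input_ids)))
    return buffer_len + len(input_ids)


buf = {"input_ids": [], "attention_mask": [], "labels": [], "position_ids": []}
n = append_example(buf, 0, [11, 12], [1, 1], [11, 12])
n = append_example(buf, n, [21, 22, 23], [1, 1, 1], [21, 22, 23])
print(n, buf["position_ids"], buf["attention_mask"])
```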
Description
Refactor for #1783: deprecate dataset_processes in favor of dataset_num_proc.