feat:add support dataset_num_processes #3071

Closed
ved1beta wants to merge 7 commits into axolotl-ai-cloud:main from ved1beta:dataset_num_proc

Conversation

@ved1beta
Member

@ved1beta ved1beta commented Aug 15, 2025

Description

Refactor for #1783: deprecate dataset_processes in favor of dataset_num_proc.

Summary by CodeRabbit

  • New Features

    • Added ConstantLengthDataset to yield fixed-length token sequences from iterable datasets.
    • Introduced AXOLOTL_DATASET_NUM_PROC to control dataset processing concurrency.
    • Added new config key dataset_num_proc with backward-compatible alias and stricter validation.
  • Refactor

    • Replaced dataset_processes with dataset_num_proc across data loading, filtering, mapping, and training steps.
  • Documentation

    • Updated debugging guide and examples to use dataset_num_proc and corresponding CLI flag.
  • Tests

    • Updated tests to reflect the dataset_num_proc configuration.
  • Chores

    • CI and single-GPU runtime now export AXOLOTL_DATASET_NUM_PROC.

@coderabbitai
Contributor

coderabbitai Bot commented Aug 15, 2025

Walkthrough

Renames the dataset parallelism configuration from dataset_processes to dataset_num_proc across config, environment, code paths, and docs; adds AXOLOTL_DATASET_NUM_PROC handling. Updates builders/trainers/data utilities to use the new key. Introduces ConstantLengthDataset in datasets.py and adds type annotations. Updates tests accordingly.

Changes

| Cohort / File(s) | Summary of changes |
| --- | --- |
| CI/CD env variables: cicd/Dockerfile.jinja, cicd/single_gpu.py | Add AXOLOTL_DATASET_NUM_PROC=8 to Dockerfile and single_gpu environment alongside existing AXOLOTL_DATASET_PROCESSES. |
| Dev config and docs: devtools/dev_chat_template.yml, docs/debugging.qmd | Replace dataset_processes with dataset_num_proc in template and docs; update CLI flags/examples accordingly. |
| Config schema and defaults: src/axolotl/utils/schemas/config.py | Add dataset_num_proc field and deprecated alias dataset_processes. Rename validator to default_dataset_num_proc; map alias to new field, warn on deprecation, raise on conflict; default to get_default_process_count(). |
| Default process count utility: src/axolotl/utils/datasets.py | Check AXOLOTL_DATASET_NUM_PROC first when deriving default process count, then fall back to AXOLOTL_DATASET_PROCESSES, RUNPOD_CPU_COUNT, os.cpu_count(). |
| Builders, trainer, and data utils rename: src/axolotl/core/builders/base.py, src/axolotl/utils/trainer.py, src/axolotl/utils/data/rl.py, src/axolotl/utils/data/shared.py, src/axolotl/utils/data/utils.py, src/axolotl/utils/data/wrappers.py | Switch all usages from cfg.dataset_processes to cfg.dataset_num_proc for map/filter/save/process_count; maintain existing control flow and fallbacks. |
| Datasets module additions: src/axolotl/datasets.py | Add ConstantLengthDataset class yielding fixed-length token sequences; import typing/torch/tokenizer types; annotate TokenizedPromptDataset.process signature. |
| Tests update: tests/core/test_builders.py, tests/e2e/patched/test_activation_checkpointing.py | Update configs in tests from dataset_processes to dataset_num_proc; values unchanged. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~30 minutes

Possibly related PRs

Suggested labels

ready to merge

Suggested reviewers

  • winglian
  • NanoCode012
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

🔭 Outside diff range comments (3)
src/axolotl/utils/data/rl.py (1)

114-120: Guard against num_proc == 0 for HF Datasets.map — fix required

Short: The config validator only fills a default when dataset_num_proc is None (src/axolotl/utils/schemas/config.py), so a user-specified 0 will pass through and can cause HF Datasets to error (datasets.map(num_proc=0)). Coerce falsy 0 → None at call sites (or normalize centrally).

Change to apply (RL file example):

     dataset = dataset.map(
         ds_transform_fn,
-        num_proc=cfg.dataset_num_proc,
+        num_proc=(cfg.dataset_num_proc or None),
         load_from_cache_file=not cfg.is_preprocess,
         desc="Mapping RL Dataset",
         **map_kwargs,
     )

Places that need the same attention (please update to use (cfg.dataset_num_proc or None) or centralize normalization):

  • src/axolotl/utils/data/rl.py — map at ~lines 114–120 and filter at ~lines 234–239
  • src/axolotl/utils/trainer.py — filter_map_kwargs assignment (line ~279) and map calls at ~319, ~336, ~345
  • src/axolotl/utils/data/utils.py — filter_map_kwargs assignment (~line 188)
  • src/axolotl/utils/data/wrappers.py — dataset_kwargs["process_count"] (~line 84) — verify whether None is accepted by that wrapper; otherwise default to get_default_process_count()
  • src/axolotl/core/trainers/base.py — num_processes=self.args.dataset_num_proc (~line 123) — verify downstream API accepts 0

Note: src/axolotl/utils/data/shared.py already guards with cfg.dataset_num_proc or get_default_process_count().

Recommendation: either (A) update all HF datasets.map/filter call sites to pass (cfg.dataset_num_proc or None), or (B) update the config validator (default_dataset_num_proc) to treat 0 as unset (convert 0 → None) so the value never reaches callers as 0.
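Option (B) above could be sketched as one small normalization helper applied before any value reaches map/filter. This is only an illustration: normalize_num_proc is a hypothetical name, not an existing axolotl function, and it assumes negative values should be treated the same way as 0.

```python
from typing import Optional


def normalize_num_proc(value: Optional[int]) -> Optional[int]:
    """Map unset, zero, or negative worker counts to None.

    Hypothetical helper (not part of axolotl). Returning None lets callers
    forward the result directly as `num_proc=None`.
    """
    if value is None or value < 1:
        return None
    return value


# A call site would then look like (sketch only):
#     dataset.map(fn, num_proc=normalize_num_proc(cfg.dataset_num_proc))
print(normalize_num_proc(0))     # None
print(normalize_num_proc(None))  # None
print(normalize_num_proc(4))     # 4
```

Doing this once in the config validator would keep the individual call sites unchanged, at the cost of hiding a misconfigured 0 from the user unless a warning is logged.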

src/axolotl/utils/datasets.py (1)

6-13: Harden env var parsing in get_default_process_count — fix required

Current implementation can raise ValueError for non-integer env values and may return None if os.cpu_count() is None. I verified call sites and occurrences; please apply the following change.

Files to update / review:

  • src/axolotl/utils/datasets.py — update get_default_process_count() as shown below.
  • Call sites that rely on this behavior:
    • src/axolotl/utils/schemas/config.py:1282 — data["dataset_num_proc"] = get_default_process_count()
    • src/axolotl/utils/data/shared.py:414 — num_workers = cfg.dataset_num_proc or get_default_process_count()
  • Env var occurrences (CI / packaging):
    • cicd/single_gpu.py:69-70 — sets AXOLOTL_DATASET_NUM_PROC / AXOLOTL_DATASET_PROCESSES
    • cicd/Dockerfile.jinja:12-13 — ENV AXOLOTL_DATASET_NUM_PROC / AXOLOTL_DATASET_PROCESSES
  • Deprecated key handling:
    • src/axolotl/utils/schemas/config.py:225 and 1268–1274 — dataset_processes is deprecated and mapped to dataset_num_proc

Suggested patch:

 def get_default_process_count():
-    if axolotl_dataset_num_proc := os.environ.get("AXOLOTL_DATASET_NUM_PROC"):
-        return int(axolotl_dataset_num_proc)
-    if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
-        return int(axolotl_dataset_processes)
-    if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
-        return int(runpod_cpu_count)
-    return os.cpu_count()
+    def _parse_env_int(name):
+        val = os.environ.get(name)
+        if val is None:
+            return None
+        try:
+            parsed = int(val)
+            return parsed if parsed > 0 else None
+        except ValueError:
+            return None
+
+    for key in (
+        "AXOLOTL_DATASET_NUM_PROC",
+        "AXOLOTL_DATASET_PROCESSES",
+        "RUNPOD_CPU_COUNT",
+    ):
+        parsed = _parse_env_int(key)
+        if parsed is not None:
+            return parsed
+
+    # Ensure a sensible minimum of 1
+    return max(os.cpu_count() or 1, 1)

Also add a short docstring to get_default_process_count describing precedence:
AXOLOTL_DATASET_NUM_PROC > AXOLOTL_DATASET_PROCESSES > RUNPOD_CPU_COUNT > os.cpu_count().
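For reference, a standalone version of the suggested patch with the docstring included might look like the following. It is a sketch under the same assumptions as the patch (invalid or non-positive env values are skipped rather than raised); _positive_int_env is a hypothetical helper name.

```python
import os
from typing import Optional


def _positive_int_env(name: str) -> Optional[int]:
    """Return the env var as a positive int, or None if unset or invalid."""
    raw = os.environ.get(name)
    if raw is None:
        return None
    try:
        parsed = int(raw)
    except ValueError:
        return None
    return parsed if parsed > 0 else None


def get_default_process_count() -> int:
    """Return the default number of dataset preprocessing workers.

    Precedence: AXOLOTL_DATASET_NUM_PROC > AXOLOTL_DATASET_PROCESSES >
    RUNPOD_CPU_COUNT > os.cpu_count(), with a floor of 1.
    """
    for name in (
        "AXOLOTL_DATASET_NUM_PROC",
        "AXOLOTL_DATASET_PROCESSES",
        "RUNPOD_CPU_COUNT",
    ):
        parsed = _positive_int_env(name)
        if parsed is not None:
            return parsed
    # os.cpu_count() may return None on some platforms.
    return os.cpu_count() or 1


os.environ["AXOLOTL_DATASET_NUM_PROC"] = "4"
print(get_default_process_count())  # 4
```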

src/axolotl/utils/data/shared.py (1)

414-424: Rename looks good; sanitize num_workers to avoid 0/None edge cases

The switch to cfg.dataset_num_proc is consistent. However, if dataset_num_proc is 0 (misconfigured) or get_default_process_count() returns None (rare, but possible), num_workers will become 0/None and cause failures:

  • IterableDataset branch: range(num_workers) and num_proc=num_workers will error/behave unexpectedly.
  • Non-iterable branch: min(..., num_workers) can compute to 0 and pass num_proc=0 to save_to_disk, which is invalid in HF datasets.

Clamp num_workers to at least 1 before use.

Apply this diff:

-    num_workers = cfg.dataset_num_proc or get_default_process_count()
+    num_workers = cfg.dataset_num_proc or get_default_process_count()
+    # Ensure valid worker count (HF datasets expects None or >=1; treat 0/None as 1 here)
+    try:
+        num_workers = int(num_workers) if num_workers is not None else None
+    except (TypeError, ValueError):
+        num_workers = None
+    if not num_workers or num_workers < 1:
+        num_workers = 1
🧹 Nitpick comments (7)
cicd/single_gpu.py (1)

69-71: Keep both envs in sync and centralize the default

Minor improvement: derive the value once (respecting any pre-set environment) and assign it to both keys. This avoids drift if one default changes.

Apply this diff:

-    sp_env["AXOLOTL_DATASET_NUM_PROC"] = "8"
-    sp_env["AXOLOTL_DATASET_PROCESSES"] = "8"
+    num_proc = (
+        os.environ.get("AXOLOTL_DATASET_NUM_PROC")
+        or os.environ.get("AXOLOTL_DATASET_PROCESSES")
+        or "8"
+    )
+    sp_env["AXOLOTL_DATASET_NUM_PROC"] = num_proc
+    # Deprecated; kept for transition. Remove after the deprecation window.
+    sp_env["AXOLOTL_DATASET_PROCESSES"] = num_proc
src/axolotl/utils/data/rl.py (1)

235-241: Mirror the same None/0 guard for filter’s num_proc

Same rationale as above for map: filter(num_proc=0) is risky.

Apply this diff:

     split_datasets[i] = split_datasets[i].filter(
         drop_long,
-        num_proc=cfg.dataset_num_proc,
+        num_proc=(cfg.dataset_num_proc or None),
         load_from_cache_file=not cfg.is_preprocess,
         desc="Dropping Long Sequences",
     )
src/axolotl/utils/data/utils.py (1)

186-191: Coerce falsy dataset_num_proc to None to avoid HF errors

Align behavior with a safe default: ensure num_proc is None when unset or 0.

Apply this diff:

     filter_map_kwargs = {}
     if not isinstance(dataset, IterableDataset):
-        filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
+        filter_map_kwargs["num_proc"] = (cfg.dataset_num_proc or None)
         filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
docs/debugging.qmd (1)

32-32: Add mention of the new env var for completeness

Since AXOLOTL_DATASET_NUM_PROC is now supported, it’s useful to surface it here alongside config/CLI.

Apply this diff:

-    - Set `dataset_num_proc: 1` in your axolotl config or run the training command with `--dataset_num_proc=1`.
+    - Set `dataset_num_proc: 1` in your axolotl config, export `AXOLOTL_DATASET_NUM_PROC=1`, or run the training command with `--dataset_num_proc=1`.
tests/test_datasets.py (1)

144-145: Rename to dataset_num_proc across dataset tests — LGTM; consider adding a deprecated-key coverage test

Ran the grep you provided over tests.

  • Findings:

    • "dataset_num_proc" is used across tests (examples: tests/test_packed_dataset.py, tests/test_exact_deduplication.py, tests/test_datasets.py, tests/e2e/test_llama_pretrain.py, tests/core/test_builders.py, tests/e2e/patched/test_activation_checkpointing.py).
    • "dataset_processes" (deprecated) has 0 occurrences in tests.
  • Recommendation (optional but useful): add a small unit test (config-normalization scope) that:

    • sets only dataset_processes and asserts it maps to cfg.dataset_num_proc;
    • sets both keys with different values and asserts dataset_num_proc wins and a deprecation warning is emitted.

I can draft that minimal test if you want.
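A rough shape for such a test is sketched below. It does not import axolotl: migrate_dataset_processes is a hypothetical stand-in for the default_dataset_num_proc validator, and it follows the later revision quoted in this thread, which raises on conflict. If the warn-and-prefer behavior were kept instead, the second test would assert on the resulting value rather than on an exception.

```python
import logging
import unittest

LOG = logging.getLogger("axolotl_sketch")


def migrate_dataset_processes(data: dict) -> dict:
    """Stand-in for the config validator (hypothetical, for illustration)."""
    if data.get("dataset_processes") is not None:
        if data.get("dataset_num_proc") is not None:
            raise ValueError(
                "dataset_processes and dataset_num_proc cannot both be set"
            )
        data["dataset_num_proc"] = data["dataset_processes"]
        LOG.warning("dataset_processes is deprecated; use dataset_num_proc")
    return data


class DeprecatedKeyTest(unittest.TestCase):
    def test_alias_maps_to_new_key(self):
        # Only the deprecated key is set: it should be copied and warned about.
        with self.assertLogs(LOG, level="WARNING") as captured:
            cfg = migrate_dataset_processes({"dataset_processes": 4})
        self.assertEqual(cfg["dataset_num_proc"], 4)
        self.assertIn("deprecated", captured.output[0])

    def test_conflict_raises(self):
        # Both keys set: treated as a configuration conflict.
        with self.assertRaises(ValueError):
            migrate_dataset_processes(
                {"dataset_processes": 4, "dataset_num_proc": 8}
            )
```

The cases can be run with `python -m unittest`; a real test would exercise the actual config model instead of the stand-in.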

src/axolotl/utils/trainer.py (1)

279-279: Optional: Avoid setting num_proc key when it’s None

HF datasets treats num_proc=None as single-process. You can simplify by only setting the key if cfg.dataset_num_proc is truthy to reduce kwargs noise.

-        filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
+        if cfg.dataset_num_proc:
+            filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
src/axolotl/utils/schemas/config.py (1)

226-235: Deprecation field is clear; consider marking JSON schema as deprecated for tooling

The deprecation messaging is good. For better JSON schema/OpenAPI tooling support, set a boolean deprecated flag in json_schema_extra as well.

     dataset_processes: int | None = Field(
         default=None,
         deprecated="Use `dataset_num_proc` instead. This parameter will be removed in a future version.",
         json_schema_extra={
             "description": (
                 "DEPRECATED: Use `dataset_num_proc` instead. "
                 "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
                 "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
             ),
+            "deprecated": True,
         },
     )
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d1de6f5 and 0ca4645.

📒 Files selected for processing (18)
  • cicd/Dockerfile.jinja (1 hunks)
  • cicd/single_gpu.py (1 hunks)
  • devtools/dev_chat_template.yml (1 hunks)
  • docs/debugging.qmd (2 hunks)
  • src/axolotl/core/builders/base.py (1 hunks)
  • src/axolotl/utils/data/rl.py (2 hunks)
  • src/axolotl/utils/data/shared.py (1 hunks)
  • src/axolotl/utils/data/utils.py (1 hunks)
  • src/axolotl/utils/data/wrappers.py (1 hunks)
  • src/axolotl/utils/datasets.py (1 hunks)
  • src/axolotl/utils/schemas/config.py (2 hunks)
  • src/axolotl/utils/trainer.py (5 hunks)
  • tests/core/test_builders.py (2 hunks)
  • tests/e2e/patched/test_activation_checkpointing.py (1 hunks)
  • tests/e2e/test_llama_pretrain.py (1 hunks)
  • tests/test_datasets.py (7 hunks)
  • tests/test_exact_deduplication.py (1 hunks)
  • tests/test_packed_dataset.py (1 hunks)
🧰 Additional context used
🧬 Code Graph Analysis (8)
src/axolotl/utils/data/utils.py (2)
tests/test_exact_deduplication.py (1)
  • cfg (201-216)
src/axolotl/integrations/base.py (2)
  • cfg (352-353)
  • cfg (356-357)
src/axolotl/core/builders/base.py (2)
tests/test_exact_deduplication.py (1)
  • cfg (201-216)
src/axolotl/integrations/base.py (2)
  • cfg (352-353)
  • cfg (356-357)
tests/core/test_builders.py (1)
tests/test_exact_deduplication.py (1)
  • cfg (201-216)
src/axolotl/utils/data/wrappers.py (1)
tests/test_exact_deduplication.py (1)
  • cfg (201-216)
src/axolotl/utils/data/shared.py (2)
tests/test_exact_deduplication.py (1)
  • cfg (201-216)
src/axolotl/utils/datasets.py (1)
  • get_default_process_count (6-13)
src/axolotl/utils/trainer.py (2)
tests/test_exact_deduplication.py (1)
  • cfg (201-216)
src/axolotl/integrations/base.py (2)
  • cfg (352-353)
  • cfg (356-357)
src/axolotl/utils/data/rl.py (2)
tests/test_exact_deduplication.py (1)
  • cfg (201-216)
src/axolotl/integrations/base.py (2)
  • cfg (352-353)
  • cfg (356-357)
src/axolotl/utils/schemas/config.py (1)
src/axolotl/utils/datasets.py (1)
  • get_default_process_count (6-13)
🔇 Additional comments (12)
devtools/dev_chat_template.yml (1)

16-16: Renamed key aligns with the new canonical config

Switching to dataset_num_proc here matches the updated schema and docs. Looks good.

tests/e2e/patched/test_activation_checkpointing.py (1)

73-74: Test config updated to the new key

Using dataset_num_proc in this test fixture is consistent with the rename and should exercise the new code path.

src/axolotl/utils/data/wrappers.py (1)

83-86: Approved — cfg.dataset_num_proc is correct; only deprecated alias remains

Search shows only deprecation/alias handling for dataset_processes; the wrapper change is safe.

  • src/axolotl/utils/schemas/config.py:225 — dataset_processes Field with deprecated="Use dataset_num_proc instead."
  • src/axolotl/utils/schemas/config.py:1268–1280 — default_dataset_num_proc copies dataset_processes → dataset_num_proc and logs deprecation warnings.
  • src/axolotl/utils/data/wrappers.py:83–86 — uses cfg.dataset_num_proc for process_count (change under review).
tests/e2e/test_llama_pretrain.py (1)

33-34: LGTM: renamed key to dataset_num_proc

Consistent with the PR-wide rename; keeps behavior unchanged (value=1).

tests/test_exact_deduplication.py (1)

213-214: LGTM: test fixture key rename

Matches the new canonical config key and maintains the same semantics (4 workers).

tests/test_packed_dataset.py (1)

102-102: Renamed key to dataset_num_proc — LGTM

Config rename in the test input matches the repo-wide migration and keeps the same semantics.

tests/core/test_builders.py (1)

85-85: Consistent switch to dataset_num_proc in builders tests — LGTM

Both the base fixture and the per-test override now use the canonical key; aligns with updated builders reading cfg.dataset_num_proc.

Also applies to: 446-446

docs/debugging.qmd (1)

104-104: CLI args updated to dataset_num_proc — LGTM

Docs reflect the new canonical flag in the VSCode launch example.

src/axolotl/utils/data/shared.py (1)

433-440: Prevent passing num_proc=0 to save_to_disk

With the clamped num_workers above, this becomes safe. Without it, the expression can yield 0 and break. No change needed here if you adopt the clamp; otherwise, consider guarding num_workers to be >=1.

src/axolotl/utils/trainer.py (1)

279-279: Consistent rename to dataset_num_proc — ensure value is >= 1 everywhere

All updates to use cfg.dataset_num_proc look correct and consistent. One behavioral edge case: if cfg.dataset_num_proc is 0 (or negative), HF datasets and MultipackBatchSampler will misbehave or error. In shared.py you already fall back via “or get_default_process_count()”, but here you pass the value through directly.

Prefer centralizing a clamp at config load (recommended below), or locally coalescing to None/1.

If you don’t adopt the config-level clamp, a minimal local guard would look like:

-        filter_map_kwargs["num_proc"] = cfg.dataset_num_proc
+        filter_map_kwargs["num_proc"] = (
+            None if cfg.dataset_num_proc in (None, 0) else max(1, cfg.dataset_num_proc)
+        )

Same applies to:

  • Line 319: num_proc in map()
  • Line 336: num_proc in PoSE map()
  • Line 345: num_proc in eval PoSE map()
  • Line 470: num_processes in MultipackBatchSampler

Also applies to: 319-319, 336-336, 345-345, 470-472

src/axolotl/utils/schemas/config.py (2)

17-24: Confirm that get_default_process_count() behavior is as intended

This validator relies on get_default_process_count() which now prioritizes AXOLOTL_DATASET_NUM_PROC over AXOLOTL_DATASET_PROCESSES and then RUNPOD_CPU_COUNT. That’s consistent with the PR goals. Just ensure documentation mentions the new env var and that “0” is not a valid value (after the clamp).

Would you like me to update docs to explicitly state:

  • dataset_num_proc must be >= 1
  • AXOLOTL_DATASET_NUM_PROC is preferred over AXOLOTL_DATASET_PROCESSES
  • If neither is set, defaults to os.cpu_count() (or 1 if undetermined)?

1-1: Repo-wide check — OK: only compat-layer usages of dataset_processes remain

Search found references only in the compatibility code inside src/axolotl/utils/schemas/config.py; no other occurrences in the repo.

  • src/axolotl/utils/schemas/config.py — line ~225: schema field dataset_processes: int | None = Field(..., deprecated="Use `dataset_num_proc` instead...")
  • src/axolotl/utils/schemas/config.py — lines ~1268–1279: default_dataset_num_proc classmethod that copies dataset_processes -> dataset_num_proc and logs deprecation warnings

No further action required.

Comment thread cicd/Dockerfile.jinja
Comment on lines +12 to 13
ENV AXOLOTL_DATASET_NUM_PROC="8"
ENV AXOLOTL_DATASET_PROCESSES="8"
Contributor

💡 Verification agent

🧩 Analysis chain

Add deprecation note or plan for the legacy env var

Setting both AXOLOTL_DATASET_NUM_PROC and AXOLOTL_DATASET_PROCESSES is fine for a transition period. Consider annotating the legacy variable to avoid future confusion and to make the removal plan explicit.

Apply this diff to add a brief deprecation note:

 ENV AXOLOTL_DATASET_NUM_PROC="8"
-ENV AXOLOTL_DATASET_PROCESSES="8"
+# Deprecated: kept for backward-compat. Prefer AXOLOTL_DATASET_NUM_PROC; remove after deprecation window.
+ENV AXOLOTL_DATASET_PROCESSES="8"

Run this to verify the precedence in code (NUM_PROC should be preferred over PROCESSES):


🏁 Script executed:

#!/bin/bash
# Expect: A code path that reads AXOLOTL_DATASET_NUM_PROC first, then falls back to AXOLOTL_DATASET_PROCESSES.
rg -nP -C3 'AXOLOTL_DATASET_(NUM_PROC|PROCESSES)' src | sed -n '1,200p'

# Also check for any lingering config key usages of `dataset_processes` outside of deprecation/alias handling.
rg -nP -C2 '\bdataset_processes\b' | sed -n '1,200p'

Length of output: 2466


Annotate legacy AXOLOTL_DATASET_PROCESSES in cicd/Dockerfile.jinja

Verified: AXOLOTL_DATASET_NUM_PROC is preferred at runtime and dataset_processes is already deprecated/migrated in the config — please add a deprecation note in the Dockerfile to avoid confusion.

Files to note:

  • cicd/Dockerfile.jinja — lines 12–13: add deprecation comment for the legacy env var.
  • Verified precedence and deprecation handling in:
    • src/axolotl/utils/datasets.py — get_default_process_count() checks AXOLOTL_DATASET_NUM_PROC then falls back to AXOLOTL_DATASET_PROCESSES (lines ~7–11).
    • src/axolotl/utils/schemas/config.py — dataset_processes is declared deprecated (line ~225) and migrated into dataset_num_proc with warnings (lines ~1268–1279).

Apply this diff:

 ENV AXOLOTL_DATASET_NUM_PROC="8"
-ENV AXOLOTL_DATASET_PROCESSES="8"
+# Deprecated: kept for backward-compat. Prefer AXOLOTL_DATASET_NUM_PROC; remove after deprecation window.
+ENV AXOLOTL_DATASET_PROCESSES="8"
🤖 Prompt for AI Agents
In cicd/Dockerfile.jinja around lines 12–13, the legacy
AXOLOTL_DATASET_PROCESSES env var is set without any deprecation note; add a
concise comment next to that ENV line stating that AXOLOTL_DATASET_NUM_PROC is
the preferred runtime variable, AXOLOTL_DATASET_PROCESSES is deprecated and kept
only for backward compatibility, and that code now prefers
AXOLOTL_DATASET_NUM_PROC (see src/axolotl/utils/datasets.py and
src/axolotl/utils/schemas/config.py for migration/deprecation handling). Ensure
the comment is short, mentions backward compatibility, and does not remove the
existing ENV assignment.

Comment thread src/axolotl/core/builders/base.py Outdated
Comment thread src/axolotl/utils/schemas/config.py
Comment thread cicd/Dockerfile.jinja
@winglian winglian requested a review from djsaunde August 18, 2025 12:43
Collaborator

@djsaunde djsaunde left a comment


LGTM, thanks.

Comment thread docs/debugging.qmd Outdated
Co-authored-by: Dan Saunders <danjsaund@gmail.com>
Comment thread src/axolotl/core/builders/base.py Outdated
Comment on lines +522 to +523
if self.cfg.dataset_num_proc:
    training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_num_proc
Collaborator

Could add it to the above list instead of separate if condition here
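One way to read this suggestion, sketched with made-up names: the "above list" in builders/base.py is not shown in this thread, so PASSTHROUGH_KEYS and collect_training_args below are hypothetical, and dataloader_num_workers is included only as a plausible neighbor in that list.

```python
from types import SimpleNamespace

# Hypothetical list of config keys forwarded unchanged to the training args;
# the real list in builders/base.py is not quoted in this thread.
PASSTHROUGH_KEYS = ["dataloader_num_workers", "dataset_num_proc"]


def collect_training_args(cfg) -> dict:
    """Copy every truthy passthrough key from cfg into a kwargs dict."""
    kwargs = {}
    for key in PASSTHROUGH_KEYS:
        value = getattr(cfg, key, None)
        if value:
            kwargs[key] = value
    return kwargs


cfg = SimpleNamespace(dataset_num_proc=4, dataloader_num_workers=None)
print(collect_training_args(cfg))  # {'dataset_num_proc': 4}
```

Adding dataset_num_proc to an existing list like this removes the need for a separate if block, provided the list already uses the same truthiness check.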

Comment on lines +1275 to +1279
else:
    LOG.warning(
        "Both dataset_processes and dataset_num_proc are set. "
        "Using dataset_num_proc and ignoring dataset_processes."
    )
Collaborator

Should probably raise an error as it's a conflict?

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (3)
src/axolotl/utils/schemas/config.py (3)

226-235: Use boolean JSON Schema deprecation flag; keep explanation in description

deprecated= should be a boolean in JSON Schema. Move it under json_schema_extra as deprecated: True and avoid duplicating the message.

-    dataset_processes: int | None = Field(
-        default=None,
-        deprecated="Use `dataset_num_proc` instead. This parameter will be removed in a future version.",
-        json_schema_extra={
-            "description": (
-                "DEPRECATED: Use `dataset_num_proc` instead. "
-                "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
-                "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
-            )
-        },
-    )
+    dataset_processes: int | None = Field(
+        default=None,
+        json_schema_extra={
+            "deprecated": True,
+            "description": (
+                "DEPRECATED: Use `dataset_num_proc` instead. "
+                "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
+                "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
+            ),
+        },
+    )

236-244: Enforce >=1 at schema level and document env override order

Add ge=1 and note the override precedence including the new env var.

-    dataset_num_proc: int | None = Field(
-        default=None,
-        json_schema_extra={
-            "description": (
-                "The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` if not set.\n"
-                "For Runpod VMs, it will default to number of vCPUs via RUNPOD_CPU_COUNT."
-            )
-        },
-    )
+    dataset_num_proc: int | None = Field(
+        default=None,
+        ge=1,
+        json_schema_extra={
+            "description": (
+                "The maximum number of processes to use while preprocessing your input dataset. "
+                "Override/default precedence: AXOLOTL_DATASET_NUM_PROC > AXOLOTL_DATASET_PROCESSES > RUNPOD_CPU_COUNT > os.cpu_count()."
+            )
+        },
+    )

236-244: Support dataset_num_processes alias or update docs: the PR title/docs refer to dataset_num_processes but the schema only accepts dataset_num_proc; either add validation_alias=AliasChoices("dataset_num_proc","dataset_num_processes") on that Field or align the title/docs accordingly.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between def8288 and b7c27ce.

📒 Files selected for processing (2)
  • src/axolotl/core/builders/base.py (1 hunks)
  • src/axolotl/utils/schemas/config.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/core/builders/base.py
🧰 Additional context used
🧬 Code graph analysis (1)
src/axolotl/utils/schemas/config.py (1)
src/axolotl/utils/datasets.py (1)
  • get_default_process_count (6-13)
🪛 Ruff (0.12.2)
src/axolotl/utils/schemas/config.py

1276-1279: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (7)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
  • GitHub Check: preview

Comment on lines +1267 to 1284
def default_dataset_num_proc(cls, data):
    if data.get("dataset_processes") is not None:
        if data.get("dataset_num_proc") is None:
            data["dataset_num_proc"] = data["dataset_processes"]
            LOG.warning(
                "dataset_processes is deprecated and will be removed in a future version. "
                "Please use dataset_num_proc instead."
            )
        else:
            raise ValueError(
                "Both dataset_processes and dataset_num_proc are set. "
                "Using dataset_num_proc and ignoring dataset_processes."
            )

    if data.get("dataset_num_proc") is None:
        data["dataset_num_proc"] = get_default_process_count()

    return data
Contributor

🛠️ Refactor suggestion

Harden defaulting/parsing, clamp to >=1, and fix conflict error message

  • The current flow can leave dataset_num_proc as 0, None, or a non-integer value, which leads to runtime errors downstream.
  • The conflict error message says one value will be ignored, but the code raises instead; make the message consistent with the behavior.
 @model_validator(mode="before")
 @classmethod
 def default_dataset_num_proc(cls, data):
-    if data.get("dataset_processes") is not None:
-        if data.get("dataset_num_proc") is None:
-            data["dataset_num_proc"] = data["dataset_processes"]
-            LOG.warning(
-                "dataset_processes is deprecated and will be removed in a future version. "
-                "Please use dataset_num_proc instead."
-            )
-        else:
-            raise ValueError(
-                "Both dataset_processes and dataset_num_proc are set. "
-                "Using dataset_num_proc and ignoring dataset_processes."
-            )
-
-    if data.get("dataset_num_proc") is None:
-        data["dataset_num_proc"] = get_default_process_count()
+    # Handle deprecated -> canonical
+    if data.get("dataset_processes") is not None and data.get("dataset_num_proc") is None:
+        data["dataset_num_proc"] = data["dataset_processes"]
+        LOG.warning(
+            "dataset_processes is deprecated and will be removed in a future version. "
+            "Please use dataset_num_proc instead."
+        )
+    elif data.get("dataset_processes") is not None and data.get("dataset_num_proc") is not None:
+        # Conflict: be explicit and fail fast
+        raise ValueError(
+            "Conflicting config: dataset_processes and dataset_num_proc cannot both be set. "
+            "Remove dataset_processes."
+        )
+
+    # Fill default if unset, coerce and clamp to >=1
+    if data.get("dataset_num_proc") is None:
+        try:
+            default = get_default_process_count()
+        except (TypeError, ValueError):
+            default = None
+        data["dataset_num_proc"] = default if (isinstance(default, int) and default > 0) else 1
+    else:
+        try:
+            val = int(data["dataset_num_proc"])
+        except (TypeError, ValueError):
+            raise ValueError("dataset_num_proc must be an integer >= 1") from None
+        if val < 1:
+            LOG.warning("dataset_num_proc must be >= 1; clamping to 1.")
+            val = 1
+        data["dataset_num_proc"] = val

🤖 Prompt for AI Agents
In src/axolotl/utils/schemas/config.py around lines 1267-1284, harden parsing in
default_dataset_num_proc: when reading dataset_processes/dataset_num_proc coerce
values to int (raise ValueError if not convertible), prefer dataset_num_proc
when both are provided but do not raise — replace the current raise with a
LOG.warning that clearly states dataset_num_proc will be used and
dataset_processes ignored, ensure the final dataset_num_proc is clamped to at
least 1 (if <=0 set to 1 and LOG.warning about the clamp), and return the
sanitized data; this prevents None/0/non-int downstream errors and makes the
conflict message consistent with behavior.
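
As a rough illustration of the hardening discussed above, here is a minimal standalone sketch, not the PR's actual validator. `normalize_dataset_num_proc`, `default_count`, and `on_conflict` are hypothetical names, and the CPU-count fallback is an assumption standing in for get_default_process_count(), whose body is not shown here. The `on_conflict` switch covers both variants mentioned in this thread (fail fast, or warn and prefer dataset_num_proc).

```python
import logging
import os

LOG = logging.getLogger(__name__)


def normalize_dataset_num_proc(data, default_count=None, on_conflict="raise"):
    """Return a copy of `data` with a sanitized `dataset_num_proc` entry.

    Hypothetical standalone helper; the PR does this inside a pydantic
    `model_validator(mode="before")` instead.
    """
    data = dict(data)
    old = data.get("dataset_processes")
    new = data.get("dataset_num_proc")

    if old is not None and new is None:
        # Only the deprecated key is set: migrate it to the canonical key.
        LOG.warning("dataset_processes is deprecated; use dataset_num_proc instead.")
        new = old
    elif old is not None and new is not None:
        if on_conflict == "raise":
            raise ValueError(
                "dataset_processes and dataset_num_proc cannot both be set; "
                "remove dataset_processes."
            )
        LOG.warning("Both keys are set; using dataset_num_proc.")

    if new is None:
        # Assumed fallback: the CPU count, or 1 when it cannot be determined.
        new = default_count if default_count is not None else (os.cpu_count() or 1)

    try:
        value = int(new)
    except (TypeError, ValueError):
        raise ValueError("dataset_num_proc must be an integer >= 1") from None

    if value < 1:
        LOG.warning("dataset_num_proc must be >= 1; clamping to 1.")
        value = 1

    data["dataset_num_proc"] = value
    return data
```

Working on a copy keeps the caller's dict untouched, which makes the helper easy to unit-test outside the schema class.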

@ved1beta ved1beta closed this Sep 4, 2025

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (5)
src/axolotl/datasets.py (5)

48-74: Use column_names instead of features.keys() for remove_columns.

column_names is exposed by both Dataset and IterableDataset, so it is more robust across dataset variants.

Apply:

-        features = dataset.features.keys()
+        columns = dataset.column_names
@@
-            remove_columns=features,
+            remove_columns=columns,

76-92: Guard IterableDataset without features and pass remove_columns conditionally.

Some streaming datasets don’t populate features.

Apply:

     if isinstance(dataset, IterableDataset):
         map_kwargs = {}
         if prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
-        features = list(dataset.features.keys())
-        return dataset.map(
-            prompt_tokenizer.tokenize_prompt,
-            remove_columns=features,
-            **map_kwargs,
-        )
+        feats = getattr(dataset, "features", None)
+        if feats is not None:
+            map_kwargs["remove_columns"] = list(feats.keys())
+        return dataset.map(prompt_tokenizer.tokenize_prompt, **map_kwargs)
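
To show the guard in isolation, here is a small sketch that builds the keyword arguments without touching the real datasets library. `build_map_kwargs` is a hypothetical helper, and the `SimpleNamespace` objects are stand-ins for streaming datasets with and without schema information.

```python
from types import SimpleNamespace


def build_map_kwargs(dataset, supports_batched):
    """Build keyword arguments for `dataset.map` without assuming `features` is set."""
    map_kwargs = {}
    if supports_batched:
        map_kwargs["batched"] = True
    # Streaming datasets may report `features` as None (or lack it entirely).
    features = getattr(dataset, "features", None)
    if features is not None:
        map_kwargs["remove_columns"] = list(features.keys())
    return map_kwargs


# Stand-ins for streaming datasets with and without schema information.
with_schema = SimpleNamespace(features={"instruction": None, "output": None})
without_schema = SimpleNamespace(features=None)

print(build_map_kwargs(with_schema, supports_batched=True))
print(build_map_kwargs(without_schema, supports_batched=False))
```

When `features` is missing, `remove_columns` is simply omitted rather than passed as an empty or invalid value.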

105-115: Set safe concat/pad tokens; avoid None EOS.

Fallback to PAD when EOS is missing; also precompute pad_token_id.

Apply:

         self.tokenizer = tokenizer
-        self.concat_token_id = tokenizer.eos_token_id
+        self.concat_token_id = (
+            tokenizer.eos_token_id
+            if tokenizer.eos_token_id is not None
+            else tokenizer.pad_token_id
+        )
+        self.pad_token_id = (
+            tokenizer.pad_token_id
+            if tokenizer.pad_token_id is not None
+            else self.concat_token_id
+        )
         self.datasets: list[IterableDataset] = datasets
         self.seq_length = seq_length
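
The fallback order above can be sketched as a standalone function. `resolve_concat_and_pad_ids` is a hypothetical name, and raising when neither id exists is an assumption added here, not something the suggestion specifies.

```python
from types import SimpleNamespace


def resolve_concat_and_pad_ids(tokenizer):
    """Pick the concat and pad token ids, each falling back to the other."""
    eos = getattr(tokenizer, "eos_token_id", None)
    pad = getattr(tokenizer, "pad_token_id", None)
    # Compare against None explicitly: token id 0 is a valid id.
    concat_id = eos if eos is not None else pad
    pad_id = pad if pad is not None else concat_id
    if concat_id is None:
        # Assumption: fail early instead of carrying None into tensor code.
        raise ValueError("tokenizer defines neither eos_token_id nor pad_token_id")
    return concat_id, pad_id


# Stand-in tokenizer objects; a real tokenizer exposes the same attributes.
only_eos = SimpleNamespace(eos_token_id=2, pad_token_id=None)
only_pad = SimpleNamespace(eos_token_id=None, pad_token_id=0)
print(resolve_concat_and_pad_ids(only_eos))
print(resolve_concat_and_pad_ids(only_pad))
```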

191-199: Avoid dropping long examples; chunk instead.

Instead of skipping, split long records into seq_length-sized shards (with optional overlap) to preserve data.

I can provide a small helper to slice example["input_ids"]/labels into multiple sub-examples if you want it in this PR.
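
One possible shape for such a helper, as a sketch only: `chunk_example` is a hypothetical name, and it assumes every field in the example is a plain list of the same length (for instance input_ids, labels, attention_mask).

```python
def chunk_example(example, seq_length, overlap=0):
    """Split one tokenized example into shards of at most `seq_length` tokens.

    Consecutive shards share `overlap` tokens. The last shard may be shorter
    than `seq_length` and would still need padding before being yielded.
    """
    if seq_length <= 0:
        raise ValueError("seq_length must be positive")
    if not 0 <= overlap < seq_length:
        raise ValueError("overlap must satisfy 0 <= overlap < seq_length")

    stride = seq_length - overlap
    total = len(example["input_ids"])
    chunks = []
    start = 0
    while start < total:
        end = start + seq_length
        # Slice every field with the same window so they stay aligned.
        chunks.append({key: list(values[start:end]) for key, values in example.items()})
        if end >= total:
            break
        start += stride
    return chunks


long_example = {
    "input_ids": list(range(10)),
    "labels": list(range(10)),
    "attention_mask": [1] * 10,
}
for shard in chunk_example(long_example, seq_length=4, overlap=1):
    print(shard["input_ids"])
```

A non-zero overlap trades some duplicated tokens for context continuity across shard boundaries.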


94-104: Interleaving note.

If multiple datasets are provided, consider round-robin interleaving to reduce domain burstiness.

I can sketch an iterator that cycles over self.datasets and yields one example at a time with a cap per-source.
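
A minimal version of that iterator might look like the following. `round_robin` and `max_per_source` are hypothetical names, and plain lists stand in for the entries of self.datasets.

```python
from itertools import islice


def round_robin(datasets, max_per_source=None):
    """Yield one example at a time from each dataset in turn.

    Exhausted sources are dropped; `max_per_source` optionally caps how many
    examples are taken from each source.
    """
    iterators = [
        iter(ds if max_per_source is None else islice(ds, max_per_source))
        for ds in datasets
    ]
    while iterators:
        still_active = []
        for it in iterators:
            try:
                yield next(it)
            except StopIteration:
                continue  # this source is exhausted; do not keep it
            still_active.append(it)
        iterators = still_active


# Plain lists stand in for iterable datasets of different sizes.
sources = [[1, 2, 3], ["a"], ["x", "y"]]
print(list(round_robin(sources)))
print(list(round_robin(sources, max_per_source=1)))
```

Because sources are consumed lazily, this also works with streaming inputs that cannot be materialized in memory.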

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b7c27ce and 87283e6.

📒 Files selected for processing (1)
  • src/axolotl/datasets.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/axolotl/datasets.py (2)
src/axolotl/utils/logging.py (1)
  • get_logger (42-49)
src/axolotl/prompt_tokenizers.py (1)
  • PromptTokenizingStrategy (43-105)
🔇 Additional comments (1)
src/axolotl/datasets.py (1)

32-47: Wiring of dataset_num_proc to process_count is correct. The dataset_kwargs in src/axolotl/utils/data/wrappers.py maps cfg.dataset_num_proc into the process_count parameter of TokenizedPromptDataset, and all explicit TokenizedPromptDataset instantiations in tests likewise pass process_count. No further changes needed.

Comment thread src/axolotl/datasets.py
Comment on lines +116 to +124
        vocab_size = len(tokenizer.get_vocab())

        if vocab_size <= torch.iinfo(torch.int16).max:
            self.tokens_dtype = torch.int16
        elif vocab_size <= torch.iinfo(torch.int32).max:
            self.tokens_dtype = torch.int32
        else:
            self.tokens_dtype = torch.int64

⚠️ Potential issue

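Token dtype should be torch.long for model embeddings and loss computation.

nn.Embedding rejects int16 indices. It does accept int32, but loss functions such as cross_entropy expect torch.long class-index targets, so narrower dtypes are likely to break training with many HF models.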

Apply:

-        vocab_size = len(tokenizer.get_vocab())
-
-        if vocab_size <= torch.iinfo(torch.int16).max:
-            self.tokens_dtype = torch.int16
-        elif vocab_size <= torch.iinfo(torch.int32).max:
-            self.tokens_dtype = torch.int32
-        else:
-            self.tokens_dtype = torch.int64
+        # Embedding layers expect Long tensor indices; keep masks integral too.
+        self.tokens_dtype = torch.long
+        self.mask_dtype = torch.long
🤖 Prompt for AI Agents
In src/axolotl/datasets.py around lines 116-124, the code currently chooses
int16/int32/int64 based on vocab_size, but token indices must be torch.long
(int64) for nn.Embedding and HF models. Replace the branching logic so
self.tokens_dtype is always set to torch.long (torch.int64); update any
downstream casts/usages to use this dtype to ensure compatibility with
embeddings.

Comment thread src/axolotl/datasets.py
Comment on lines +152 to +182
                    buffer_len + int(add_concat_token) + example_len > self.seq_length
                ):
                    if buffer["input_ids"]:
                        input_ids = torch.cat(buffer["input_ids"], dim=-1)[
                            : self.seq_length
                        ]
                        attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
                            : self.seq_length
                        ]
                        position_ids = torch.cat(buffer["position_ids"], dim=-1)[
                            : self.seq_length
                        ]
                        labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
                        if labels.size() == input_ids.size() and (
                            attention_mask.size() == input_ids.size()
                        ):
                            yield {
                                "input_ids": input_ids,
                                "labels": labels,
                                "attention_mask": attention_mask,
                                "position_ids": position_ids,
                            }
                        else:
                            LOG.warning(
                                "Dropping batch due to tensor size mismatch "
                                f"input_ids: {input_ids.size()}, "
                                f"labels: {labels.size()}, "
                                f"attention_mask: {attention_mask.size()}"
                            )
                    buffer = {
                        "input_ids": [],

🛠️ Refactor suggestion

Guarantee constant-length output by padding/truncation at flush.

Currently the final batch can be shorter than seq_length, violating the class contract and risking collate failures.

Apply:

-                    if buffer["input_ids"]:
-                        input_ids = torch.cat(buffer["input_ids"], dim=-1)[
-                            : self.seq_length
-                        ]
-                        attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
-                            : self.seq_length
-                        ]
-                        position_ids = torch.cat(buffer["position_ids"], dim=-1)[
-                            : self.seq_length
-                        ]
-                        labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
-                        if labels.size() == input_ids.size() and (
-                            attention_mask.size() == input_ids.size()
-                        ):
-                            yield {
-                                "input_ids": input_ids,
-                                "labels": labels,
-                                "attention_mask": attention_mask,
-                                "position_ids": position_ids,
-                            }
-                        else:
-                            LOG.warning(
-                                "Dropping batch due to tensor size mismatch "
-                                f"input_ids: {input_ids.size()}, "
-                                f"labels: {labels.size()}, "
-                                f"attention_mask: {attention_mask.size()}"
-                            )
+                    if buffer["input_ids"]:
+                        input_ids = torch.cat(buffer["input_ids"], dim=-1)
+                        attention_mask = torch.cat(buffer["attention_mask"], dim=-1)
+                        labels = torch.cat(buffer["labels"], dim=-1)
+                        cur_len = input_ids.size(-1)
+                        # truncate if overflowed
+                        if cur_len > self.seq_length:
+                            input_ids = input_ids[: self.seq_length]
+                            attention_mask = attention_mask[: self.seq_length]
+                            labels = labels[: self.seq_length]
+                            cur_len = self.seq_length
+                        # pad if underfull
+                        if cur_len < self.seq_length:
+                            pad_len = self.seq_length - cur_len
+                            input_ids = torch.cat(
+                                [input_ids,
+                                 torch.full((pad_len,), self.pad_token_id, dtype=self.tokens_dtype)],
+                                dim=-1,
+                            )
+                            attention_mask = torch.cat(
+                                [attention_mask,
+                                 torch.zeros((pad_len,), dtype=self.mask_dtype)],
+                                dim=-1,
+                            )
+                            labels = torch.cat(
+                                [labels,
+                                 torch.full((pad_len,), -100, dtype=self.tokens_dtype)],
+                                dim=-1,
+                            )
+                        # recompute contiguous position ids for the finalized chunk
+                        position_ids = torch.arange(self.seq_length, dtype=self.tokens_dtype)
+                        yield {
+                            "input_ids": input_ids,
+                            "labels": labels,
+                            "attention_mask": attention_mask,
+                            "position_ids": position_ids,
+                        }
🤖 Prompt for AI Agents
In src/axolotl/datasets.py around lines 152-182, the flush path can yield
tensors shorter than self.seq_length; modify the flush so every yielded tensor
is exactly self.seq_length by padding/truncating after concatenation: after
building input_ids, attention_mask, position_ids and labels and slicing to at
most self.seq_length, if any tensor is shorter than self.seq_length pad
input_ids with self.pad_token_id (or 0 if not available), pad labels with
self.ignore_index (commonly -100) so they do not contribute to loss, pad
attention_mask with 0, and pad position_ids with 0 (or appropriate incremental
positions if required) until length == self.seq_length; keep the existing
size-check and warning, then yield the fixed-length batch and reset buffer.
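
The pad-or-truncate step can be illustrated without torch. This is a simplified sketch on plain lists, not the class's actual flush path: `pad_or_truncate` is a hypothetical helper, and position ids are regenerated as one contiguous range, following the suggested diff rather than the padding variant described in the prompt above.

```python
IGNORE_INDEX = -100  # label value commonly ignored by the loss


def pad_or_truncate(input_ids, attention_mask, labels, seq_length, pad_token_id):
    """Return a sample whose fields all have exactly `seq_length` entries."""

    def fit(values, fill):
        values = list(values)[:seq_length]  # truncate on overflow
        return values + [fill] * (seq_length - len(values))  # pad when short

    return {
        "input_ids": fit(input_ids, pad_token_id),
        # Padded positions get mask 0 so the model does not attend to them.
        "attention_mask": fit(attention_mask, 0),
        # Padded positions get IGNORE_INDEX so they do not contribute to loss.
        "labels": fit(labels, IGNORE_INDEX),
        "position_ids": list(range(seq_length)),
    }


short = pad_or_truncate([5, 6, 7], [1, 1, 1], [5, 6, 7], seq_length=5, pad_token_id=0)
print(short)
```

Every field is passed through the same `fit` helper, so the size-mismatch case that the original code logs and drops cannot arise from the flush itself.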

Comment thread src/axolotl/datasets.py Outdated
Comment on lines 204 to 221

⚠️ Potential issue

Fix attention_mask scaling and position_ids; enforce correct dtypes.

attention_mask must be 0/1, not multiplied by idx. position_ids should advance with buffer_len. Also use torch.long.

Apply:

-                        input_ids_with_concat = torch.tensor(
-                            input_ids, dtype=self.tokens_dtype
-                        )
-                        attention_mask_with_concat = torch.tensor(
-                            [idx * m for m in attention_mask], dtype=torch.int16
-                        )
-                        labels_with_concat = torch.tensor(
-                            labels, dtype=self.tokens_dtype
-                        )
-                        position_ids = torch.arange(
-                            len(input_ids), dtype=self.tokens_dtype
-                        )
+                        input_ids_with_concat = torch.tensor(
+                            input_ids, dtype=self.tokens_dtype
+                        )
+                        attention_mask_with_concat = torch.tensor(
+                            attention_mask, dtype=self.mask_dtype
+                        )
+                        labels_with_concat = torch.tensor(
+                            labels, dtype=self.tokens_dtype
+                        )
+                        position_ids = torch.arange(
+                            buffer_len, buffer_len + len(input_ids), dtype=self.tokens_dtype
+                        )
🤖 Prompt for AI Agents
In src/axolotl/datasets.py around lines 204–221, fix the attention_mask and
position_ids creation and enforce torch.long dtype: do not scale attention_mask
by idx (keep it as 0/1), convert it to a long tensor (e.g., (attention_mask >
0).long() or torch.tensor(attention_mask, dtype=torch.long)), create
position_ids using the current buffer_len offset (torch.arange(buffer_len,
buffer_len + len(input_ids), dtype=torch.long)) so positions advance correctly,
and construct input_ids and labels tensors with dtype=torch.long as well; then
append these long-typed tensors to the buffer and update buffer_len by
len(input_ids).
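
To make the intended bookkeeping concrete, here is a simplified sketch that uses flat Python lists in place of the per-example tensors the class buffers. `append_example` is a hypothetical helper; it keeps the mask binary and offsets positions by the current buffer length, as this comment proposes.

```python
def append_example(buffer, buffer_len, example):
    """Append one tokenized example to the packing buffer.

    Returns the new buffer length so the caller can keep tracking it.
    """
    n = len(example["input_ids"])
    buffer["input_ids"].extend(example["input_ids"])
    buffer["labels"].extend(example["labels"])
    # Keep the mask strictly 0/1 rather than scaling it by a sample index.
    buffer["attention_mask"].extend(1 if m else 0 for m in example["attention_mask"])
    # Positions continue from where the previous example ended.
    buffer["position_ids"].extend(range(buffer_len, buffer_len + n))
    return buffer_len + n


buffer = {"input_ids": [], "labels": [], "attention_mask": [], "position_ids": []}
length = 0
length = append_example(
    buffer, length, {"input_ids": [11, 12], "labels": [11, 12], "attention_mask": [1, 1]}
)
length = append_example(
    buffer, length, {"input_ids": [21, 22, 23], "labels": [-100, 22, 23], "attention_mask": [1, 1, 1]}
)
print(buffer["position_ids"], length)
```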
