Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/axolotl/utils/data/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.utils.data.utils import deduplicate_and_log_datasets, md5
from axolotl.utils.datasets import get_default_process_count
from axolotl.utils.dict import DictDefault
from axolotl.utils.logging import get_logger

Expand Down Expand Up @@ -410,7 +411,7 @@ def save_preprocessed_dataset(
) -> None:
"""Save preprocessed dataset to disk and optionally push to the HF Hub."""
prepared_ds_path = get_prepared_dataset_path(cfg, dataset_hash)
num_workers = cfg.dataset_processes
num_workers = cfg.dataset_processes or get_default_process_count()
if isinstance(dataset, IterableDataset):
ds_from_iter = Dataset.from_generator(
functools.partial(_generate_from_iterable_dataset, dataset),
Expand All @@ -432,7 +433,7 @@ def save_preprocessed_dataset(
os.makedirs(prepared_ds_path, exist_ok=True)
dataset.save_to_disk(
str(prepared_ds_path),
num_proc=num_workers,
num_proc=min(max(1, len(dataset) // 8), num_workers),
max_shard_size=None,
num_shards=cfg.num_dataset_shards_to_save,
)
Expand Down
11 changes: 11 additions & 0 deletions src/axolotl/utils/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""helper functions for datasets"""

import os


def get_default_process_count():
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
return int(axolotl_dataset_processes)
if runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
return int(runpod_cpu_count)
return os.cpu_count()
9 changes: 2 additions & 7 deletions src/axolotl/utils/schemas/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

# pylint: disable=too-many-lines

import os
from typing import Annotated, Any, Literal

from annotated_types import MinLen
Expand All @@ -15,6 +14,7 @@
model_validator,
)

from axolotl.utils.datasets import get_default_process_count
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.datasets import (
DatasetConfig,
Expand Down Expand Up @@ -1211,11 +1211,6 @@ def default_dataloader_opts(cls, data):
@classmethod
def default_dataset_processes(cls, data):
if data.get("dataset_processes") is None:
if axolotl_dataset_processes := os.environ.get("AXOLOTL_DATASET_PROCESSES"):
data["dataset_processes"] = int(axolotl_dataset_processes)
elif runpod_cpu_count := os.environ.get("RUNPOD_CPU_COUNT"):
data["dataset_processes"] = int(runpod_cpu_count)
else:
data["dataset_processes"] = os.cpu_count()
data["dataset_processes"] = get_default_process_count()

return data
2 changes: 2 additions & 0 deletions tests/core/test_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def fixture_base_cfg():
"ddp_timeout": 1800,
"ddp_bucket_cap_mb": 25,
"ddp_broadcast_buffers": False,
"dataset_processes": 4,
}
)

Expand Down Expand Up @@ -440,6 +441,7 @@ def test_custom_optimizer_cls_and_kwargs(
]
else:
raise ValueError(f"Unhandled cfg_string: {cfg_string}")
cfg["dataset_processes"] = 4

if cfg_string == "grpo_cfg":
rewards_dir = tmp_path / "rewards_test"
Expand Down
7 changes: 7 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def test_load_from_save_to_disk(self, tokenizer, dataset_fixture):
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)

Expand Down Expand Up @@ -179,6 +180,7 @@ def test_load_from_dir_of_parquet(self, tokenizer, dataset_fixture):
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)

Expand Down Expand Up @@ -217,6 +219,7 @@ def test_load_from_dir_of_json(self, tokenizer, dataset_fixture):
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)

Expand Down Expand Up @@ -249,6 +252,7 @@ def test_load_from_single_parquet(self, tokenizer, dataset_fixture):
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)

Expand Down Expand Up @@ -281,6 +285,7 @@ def test_load_from_single_json(self, tokenizer, dataset_fixture):
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)

Expand Down Expand Up @@ -365,6 +370,7 @@ def test_load_hub_with_revision_with_dpo(
"rl": "dpo",
"chat_template": "llama3",
"datasets": [ALPACA_MESSAGES_CONFIG_REVISION],
"dataset_processes": 4,
}
)

Expand Down Expand Up @@ -466,6 +472,7 @@ def test_loading_local_dataset_folder(self, tokenizer):
"type": "alpaca",
},
],
"dataset_processes": 4,
}
)

Expand Down
1 change: 1 addition & 0 deletions tests/test_exact_deduplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def cfg(self):
ALPACA_MESSAGES_CONFIG_REVISION,
ALPACA_MESSAGES_CONFIG_REVISION,
],
"dataset_processes": 4,
}
)
yield fixture
Expand Down
1 change: 1 addition & 0 deletions tests/test_packed_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def test_lora_packing(self, temp_dir):
"type": "alpaca",
},
],
"dataset_processes": 4,
"num_epochs": 1,
"max_steps": 20,
"save_steps": 10,
Expand Down