diff --git a/pyproject.toml b/pyproject.toml index 69e38f51c3..06a7f60d5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,6 +126,9 @@ megatron-core = { path = "3rdparty/Megatron-LM/", editable = true } recipes = [ "nemo-run", ] +parquet = [ + "pyarrow>=14.0.0", +] tensor-inspect = [ "nvdlfw-inspect==0.2.1", ] diff --git a/src/megatron/bridge/data/builders/finetuning_dataset.py b/src/megatron/bridge/data/builders/finetuning_dataset.py index 0dff48fda1..8bd141cbe1 100644 --- a/src/megatron/bridge/data/builders/finetuning_dataset.py +++ b/src/megatron/bridge/data/builders/finetuning_dataset.py @@ -13,6 +13,7 @@ # limitations under the License. import logging +import warnings from pathlib import Path from typing import Any, Optional, Union @@ -20,6 +21,10 @@ from megatron.core.msc_utils import MultiStorageClientFeature from megatron.core.tokenizers.text.libraries import HuggingFaceTokenizer +from megatron.bridge.data.datasets.packed_parquet import ( + is_packed_parquet_spec, + resolve_packed_parquet_paths, +) from megatron.bridge.data.datasets.packed_sequence import PackedSequenceSpecs from megatron.bridge.data.datasets.sft import create_sft_dataset from megatron.bridge.utils.common_utils import get_rank_safe, print_rank_0 @@ -91,37 +96,100 @@ def prepare_data(self) -> None: self.prepare_packed_data() def prepare_packed_data(self) -> None: - """Prepare packed sequence data files if configured.""" - if self.packed_sequence_size > 0: - from megatron.bridge.data.datasets.packed_sequence import prepare_packed_sequence_data - - if not self.train_path_packed.is_file(): - print_rank_0(f"Preparing packed training data at {self.train_path_packed}") - prepare_packed_sequence_data( - input_path=self.train_path, - output_path=self.train_path_packed, - packed_sequence_size=self.packed_sequence_size, - tokenizer=self.tokenizer, - max_seq_length=self.seq_length, - seed=self.seed, - output_metadata_path=self.pack_metadata, - dataset_kwargs=self.dataset_kwargs, - 
def _prepare_packed_split(
    self,
    split_name: str,
    packed_path: Union[str, Path],
    input_path: Path,
) -> None:
    """Prepare a single packed data split unless it already exists.

    When ``_packed_path_exists`` reports that ``packed_path`` is already
    present, nothing is written. Otherwise the raw dataset at ``input_path``
    is packed and stored at ``packed_path``.

    A legacy ``.npy`` target is still prepared, but a ``DeprecationWarning``
    is emitted first: the warning text announces removal in the *next*
    release, so the feature has to keep working in this one.

    Args:
        split_name: Name of the split (used in log messages only).
        packed_path: Output path for the packed data.
        input_path: Input path to the raw dataset.
    """
    from megatron.bridge.data.datasets.packed_sequence import prepare_packed_sequence_data

    if self._packed_path_exists(packed_path):
        print_rank_0(f"Skipping packed {split_name} data preparation - already exists: {packed_path}")
        return

    if str(packed_path).lower().endswith(".npy"):
        # Warn and fall through. Returning here would leave the split
        # unprepared, contradicting the "deprecated" wording of the message.
        # stacklevel=3 attributes the warning to the caller of
        # prepare_packed_data(), which is the frame that chose the path.
        warnings.warn(
            "Automatic .npy packed sequence preparation is deprecated and will be removed in the next release. "
            "Please use packed parquet format instead.",
            DeprecationWarning,
            stacklevel=3,
        )

    print_rank_0(f"Preparing packed {split_name} data at {packed_path}")
    prepare_packed_sequence_data(
        input_path=input_path,
        output_path=packed_path,
        output_metadata_path=self.pack_metadata,
        packed_sequence_size=self.packed_sequence_size,
        tokenizer=self.tokenizer,
        max_seq_length=self.seq_length,
        seed=self.seed,
        dataset_kwargs=self.dataset_kwargs,
        pad_seq_to_mult=self._pad_seq_to_mult,
    )
Args: - path: Path to the dataset file + path: Path to the dataset file or packed parquet spec pack_metadata_path: Path to the packed sequence metadata is_test: Whether this is a test dataset **kwargs: Additional arguments to pass to the dataset constructor @@ -198,17 +266,44 @@ def _create_dataset( Returns: The created dataset """ - if MultiStorageClientFeature.is_enabled(): - msc = MultiStorageClientFeature.import_package() - path_exists = msc.Path(path).exists() + path_str = str(path) + + # Check if path exists - handle packed parquet specs differently + if is_packed_parquet_spec(path_str): + # For packed parquet specs, check via resolution + try: + resolved = resolve_packed_parquet_paths(path_str) + path_exists = len(resolved) > 0 + except ValueError: + path_exists = False else: - path_exists = Path(path).exists() + # Standard file/path existence check + if MultiStorageClientFeature.is_enabled(): + msc = MultiStorageClientFeature.import_package() + path_exists = msc.Path(path_str).exists() + else: + path_exists = Path(path_str).exists() if not path_exists: print_rank_0(f"Warning: Dataset path {path} does not exist") return None is_not_packing = self.packed_sequence_size <= 0 + + # For packed parquet from external sources, only pass metadata if pad_cu_seqlens is True + # This avoids "missing metadata" errors when using externally prepared packed data + effective_metadata_path = None + if not is_not_packing: + if self._pad_cu_seqlens: + # pad_cu_seqlens requires metadata + effective_metadata_path = pack_metadata_path + elif is_packed_parquet_spec(path_str): + # Externally prepared packed parquet without pad_cu_seqlens doesn't need metadata + effective_metadata_path = None + else: + # .npy files prepared by MB include metadata + effective_metadata_path = pack_metadata_path + return create_sft_dataset( path, tokenizer=self.tokenizer, @@ -216,7 +311,7 @@ def _create_dataset( memmap_workers=self.memmap_workers, seed=self.seed, is_test=is_test, - 
pack_metadata_file_path=None if is_not_packing else pack_metadata_path, + pack_metadata_file_path=effective_metadata_path, pad_cu_seqlens=False if is_not_packing else self._pad_cu_seqlens, pad_seq_to_mult=1 if is_not_packing else self._pad_seq_to_mult, **kwargs, @@ -269,7 +364,7 @@ def pack_metadata(self) -> Path: @property def train_path_packed(self) -> Path: - """Path to the packed training dataset file (.npy). + """Path to the packed training dataset file. Determined by `packed_sequence_specs` or defaults based on the `default_pack_path` and `packed_sequence_size`. @@ -283,13 +378,13 @@ def train_path_packed(self) -> Path: if self.packed_sequence_size > 0: if self.packed_sequence_specs.packed_train_data_path is not None: return self.packed_sequence_specs.packed_train_data_path - return self.default_pack_path / f"training_{self.packed_sequence_size}.npy" + return self.default_pack_path / f"training_{self.packed_sequence_size}.idx.parquet" else: raise ValueError("`train_path_packed` invalid since packed sequence size is not specified.") @property def validation_path_packed(self) -> Path: - """Path to the packed validation dataset file (.npy). + """Path to the packed validation dataset file. Determined by `packed_sequence_specs` or defaults based on the `default_pack_path` and `packed_sequence_size`. 
@@ -303,7 +398,7 @@ def validation_path_packed(self) -> Path: if self.packed_sequence_size > 0: if self.packed_sequence_specs.packed_val_data_path is not None: return self.packed_sequence_specs.packed_val_data_path - return self.default_pack_path / f"validation_{self.packed_sequence_size}.npy" + return self.default_pack_path / f"validation_{self.packed_sequence_size}.idx.parquet" else: raise ValueError("`validation_path_packed` invalid since packed sequence size is not specified.") diff --git a/src/megatron/bridge/data/datasets/packed_parquet.py b/src/megatron/bridge/data/datasets/packed_parquet.py new file mode 100644 index 0000000000..e2d801349f --- /dev/null +++ b/src/megatron/bridge/data/datasets/packed_parquet.py @@ -0,0 +1,648 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Packed Parquet dataset support for SFT training. + +This module provides GPTSFTPackedParquetDataset, which reads packed sequence data +from Parquet files as an alternative to the NumPy-based GPTSFTPackedDataset. 
def is_packed_parquet_file(path: str | Path) -> bool:
    """Check whether a path or pattern ends with a packed Parquet suffix.

    Only the text of ``path`` is inspected; the filesystem is never touched.
    The comparison is case-insensitive. A glob pattern qualifies only when the
    pattern text itself ends with one of the suffixes (for example
    ``"data*.idx.parquet"``); no pattern expansion is performed.

    Args:
        path: A Path object or string path (file name or glob pattern).

    Returns:
        True if the lowercased path ends with ``.idx.parquet`` or ``.idx.pq``.
    """
    return str(path).lower().endswith((".idx.parquet", ".idx.pq"))
def _resolve_parquet_paths(file_path: str) -> list[str]:
    """Expand a path specification into a sorted list of Parquet file paths.

    Three kinds of specification are handled, in this order:

    - Glob pattern (contains ``*`` or ``?``): expanded, then filtered to
      names ending in ``.parquet`` / ``.pq``.
    - Directory: searched (non-recursively) for ``*.parquet`` and ``*.pq``.
    - Single file: accepted if it has a Parquet extension and exists.

    When the multi-storage client (MSC) feature is enabled, its path and glob
    facilities are used instead of the local filesystem.

    Args:
        file_path: Path specification (file, glob pattern, or directory).

    Returns:
        Sorted list of resolved file paths. A single, non-directory path
        without a Parquet extension yields an empty list.

    Raises:
        ValueError: If a glob or directory matches no Parquet files, if a
            single Parquet file does not exist, or if the MSC backend offers
            no glob support for a pattern.
    """
    spec = str(file_path)
    # Pick the storage backend once; None means the local filesystem.
    msc = MultiStorageClientFeature.import_package() if MultiStorageClientFeature.is_enabled() else None

    # Case 1: glob pattern.
    if "*" in spec or "?" in spec:
        if msc is None:
            candidates = glob.glob(spec)
        elif hasattr(msc, "glob"):
            # Normalize to plain strings right away.
            candidates = [str(match) for match in msc.glob(spec)]
        else:
            # No module-level glob: split into parent/pattern through
            # msc.Path so that URIs are handled by the backend.
            full = msc.Path(spec)
            parent = str(full.parent) if hasattr(full, "parent") else None
            name_pattern = full.name if hasattr(full, "name") else None
            if parent is None or name_pattern is None:
                raise ValueError(f"MSC backend does not support glob operations for pattern: {spec}")
            parent_dir = msc.Path(parent)
            if not hasattr(parent_dir, "glob"):
                raise ValueError(f"MSC backend does not support glob operations for pattern: {spec}")
            candidates = [str(match) for match in parent_dir.glob(name_pattern)]

        # Keep Parquet files only (both *.parquet and *.idx.parquet pass).
        matches = sorted(candidate for candidate in candidates if _is_parquet_file(candidate))
        if not matches:
            raise ValueError(
                f"No Parquet files found matching pattern: {spec}. Files must end with .parquet or .pq"
            )
        return matches

    # Case 2: directory.
    if msc is None:
        is_dir = Path(spec).is_dir()
    else:
        remote = msc.Path(spec)
        is_dir = remote.is_dir() if hasattr(remote, "is_dir") else False

    if is_dir:
        found = set()
        for ext in ("*.parquet", "*.pq"):
            pattern = os.path.join(spec, ext)
            if msc is None:
                found.update(glob.glob(pattern))
            elif hasattr(msc, "glob"):
                found.update(str(match) for match in msc.glob(pattern))
            elif hasattr(msc.Path(spec), "glob"):
                found.update(str(match) for match in msc.Path(spec).glob(ext))

        if not found:
            raise ValueError(f"No Parquet files found in directory: {spec}. Files must end with .parquet or .pq")
        return sorted(found)

    # Case 3: single file. A non-Parquet name is reported as "no files"
    # rather than as an error.
    if not _is_parquet_file(spec):
        return []

    exists = Path(spec).exists() if msc is None else msc.Path(spec).exists()
    if not exists:
        raise ValueError(f"Packed Parquet file not found: {spec}")

    return [spec]
def write_packed_parquet(
    rows: list[dict],
    output_path: str | Path,
    row_group_size: int = 500,
) -> None:
    """Serialize packed sequence rows into one Parquet file.

    Each row contributes one entry to the ``input_ids``, ``loss_mask`` and
    ``seq_start_id`` columns. With the multi-storage client (MSC) feature
    enabled, the table is rendered into an in-memory buffer and written
    through ``msc.open``; otherwise pyarrow writes directly to the path.

    Args:
        rows: List of dicts with keys 'input_ids', 'loss_mask', 'seq_start_id'.
            This is the output format of fill_packing_strategy().
        output_path: Path to write the Parquet file.
        row_group_size: Number of rows per row group (default 500).
    """
    pa, pq = _lazy_import_pyarrow()

    column_names = ("input_ids", "loss_mask", "seq_start_id")
    table = pa.table({name: [row[name] for row in rows] for name in column_names})
    destination = str(output_path)

    if not MultiStorageClientFeature.is_enabled():
        pq.write_table(table, destination, row_group_size=row_group_size)
        return

    msc = MultiStorageClientFeature.import_package()
    sink = pa.BufferOutputStream()
    pq.write_table(table, sink, row_group_size=row_group_size)
    with msc.open(destination, "wb") as out:
        out.write(sink.getvalue().to_pybytes())
+ + Supports multiple files via: + - Single file: "data.idx.parquet" + - Glob pattern: "data*.idx.parquet" or "shard_*.idx.pq" + - Directory: "/path/to/data/" (globs for *.parquet and *.pq) + + The Parquet file(s) must contain the following columns: + - input_ids: list - Token IDs for the packed sequence + - seq_start_id: list - Start offsets for each sub-sequence within the pack + - loss_mask: list - Per-token loss mask (0 or 1), same length as input_ids + + Example: + >>> # Single file + >>> dataset = GPTSFTPackedParquetDataset( + ... file_path="packed_data.idx.parquet", + ... tokenizer=tokenizer, + ... ) + >>> # Multiple files via glob + >>> dataset = GPTSFTPackedParquetDataset( + ... file_path="data/shard_*.idx.parquet", + ... tokenizer=tokenizer, + ... ) + """ + + def __init__( + self, + file_path: str, + tokenizer: "MegatronTokenizer", + return_cu_seqlen: bool = True, + pad_cu_seqlens: bool = False, + pack_metadata_file_path: str | None = None, + **kwargs, + ): + """Initialize the packed Parquet dataset. + + Args: + file_path: Path to packed Parquet file(s). Supports: + - Single file: "data.idx.parquet" + - Glob pattern: "data*.idx.parquet" + - Directory: "/path/to/data/" + tokenizer: The tokenizer to use. + return_cu_seqlen: Whether to return cu_seqlen for THD attention kernel. + pad_cu_seqlens: Whether to pad cu_seqlens for cudagraphs compatibility. + pack_metadata_file_path: Path to the metadata JSON file for pad_cu_seqlens. + **kwargs: Additional arguments passed to parent class. + """ + # Initialize Parquet-specific state before calling parent __init__ + # (parent calls _load_dataset which needs these) + self._file_path_spec: str = file_path # Original specification (may be glob) + self._parquet_paths: list[str] = [] # Resolved list of files + self._num_rows: int = 0 # Total rows across all files + self._file_offsets: list[int] = [] # Cumulative row counts: [0, rows_file0, rows_file0+rows_file1, ...] 
def _load_dataset(self):
    """Load Parquet metadata from all files and validate schemas.

    This method:
    1. Resolves the file path specification to actual files
    2. Reads metadata and schema from each file (row data is not read, except
       that a non-seekable MSC stream is read fully into memory)
    3. Validates that every schema contains the required columns
    4. Builds cumulative row offsets per file and per row group

    The readers used for row access are opened lazily in _ensure_reader();
    nothing opened here is kept.

    Raises:
        ValueError: If a file lacks a required column, or if the resolved
            files contain zero rows in total.
    """
    pyarrow, pq = _lazy_import_pyarrow()

    # Resolve file paths
    self._parquet_paths = _resolve_parquet_paths(self._file_path_spec)

    logger.info(f"Resolved {len(self._parquet_paths)} packed Parquet file(s) from: {self._file_path_spec}")

    # Build cumulative offsets. _file_offsets[i] is the global index of the
    # first row of file i; the final entry is the total row count.
    self._file_offsets = [0]
    self._file_row_group_offsets = []

    for file_idx, parquet_path in enumerate(self._parquet_paths):
        # Read metadata only (not actual data)
        if MultiStorageClientFeature.is_enabled():
            msc = MultiStorageClientFeature.import_package()
            handle = msc.open(str(parquet_path), "rb")
            try:
                if hasattr(handle, "seekable") and handle.seekable():
                    metadata = pq.read_metadata(handle)
                    # Rewind before the second read from the same handle.
                    handle.seek(0)
                    schema = pq.read_schema(handle)
                else:
                    # Non-seekable stream: buffer the whole file so pyarrow
                    # gets a random-access source.
                    content = handle.read()
                    buffer = pyarrow.BufferReader(content)
                    pf = pq.ParquetFile(buffer)
                    metadata = pf.metadata
                    schema = pf.schema_arrow
            finally:
                # The handle is only needed for metadata; always release it.
                handle.close()
        else:
            metadata = pq.read_metadata(parquet_path)
            schema = pq.read_schema(parquet_path)

        # Validate schema on every file to catch malformed shards early
        schema_columns = set(schema.names)
        missing_columns = REQUIRED_COLUMNS - schema_columns
        if missing_columns:
            raise ValueError(
                f"Packed Parquet file '{parquet_path}' is missing required columns: {missing_columns}. "
                f"Required columns are: {REQUIRED_COLUMNS}. "
                f"Found columns: {schema_columns}"
            )

        # Build row group offsets for this file: entry g is the index
        # (within the file) of the first row of row group g.
        row_group_offsets = [0]
        for i in range(metadata.num_row_groups):
            row_group_offsets.append(row_group_offsets[-1] + metadata.row_group(i).num_rows)
        self._file_row_group_offsets.append(row_group_offsets)

        # Update cumulative file offset
        file_rows = metadata.num_rows
        self._file_offsets.append(self._file_offsets[-1] + file_rows)

        logger.debug(
            f" File {file_idx}: {parquet_path}, {file_rows} rows in {metadata.num_row_groups} row groups"
        )

    self._num_rows = self._file_offsets[-1]

    # Validate dataset is not empty
    if self._num_rows == 0:
        raise ValueError(f"Packed Parquet dataset is empty (0 rows) for path: {self._file_path_spec}")

    logger.info(
        f"Loaded packed Parquet dataset: {self._num_rows} total rows across {len(self._parquet_paths)} file(s)"
    )
def _ensure_reader(self, file_idx: int):
    """Return the ParquetFile reader for one shard, opening it on first use.

    Readers are created on demand and kept in ``self._parquet_files`` as
    ``(ParquetFile, handle)`` pairs, where ``handle`` is the open MSC stream
    backing the reader or ``None`` when there is nothing to close. Opening
    lazily (instead of in ``__init__``) is meant to let each DataLoader
    worker create its own readers after forking.

    If opening fails, the MSC stream is closed before the exception
    propagates, so a failed open does not leak a handle.

    Args:
        file_idx: Index of the file in self._parquet_paths.

    Returns:
        The ParquetFile reader for that file.
    """
    cached = self._parquet_files.get(file_idx)
    if cached is not None:
        return cached[0]

    pyarrow, pq = _lazy_import_pyarrow()
    parquet_path = self._parquet_paths[file_idx]

    if not MultiStorageClientFeature.is_enabled():
        pf = pq.ParquetFile(parquet_path)
        self._parquet_files[file_idx] = (pf, None)
        return pf

    msc = MultiStorageClientFeature.import_package()
    handle = msc.open(str(parquet_path), "rb")

    if hasattr(handle, "seekable") and handle.seekable():
        try:
            pf = pq.ParquetFile(handle)
        except Exception:
            # The reader never took ownership; release the stream ourselves.
            handle.close()
            raise
        # The reader streams from the handle, so it must stay open until close().
        self._parquet_files[file_idx] = (pf, handle)
        return pf

    # MVP fallback: load entire file into memory for non-seekable streams
    logger.warning(f"MSC stream is not seekable, loading entire Parquet file into memory: {parquet_path}")
    try:
        content = handle.read()
    finally:
        handle.close()
    pf = pq.ParquetFile(pyarrow.BufferReader(content))
    self._parquet_files[file_idx] = (pf, None)
    return pf
def __getitem__(self, idx: int) -> dict:
    """Get a packed sample by index.

    If ``samples_mapping`` exists, ``idx`` is first translated through it.
    A negative (translated) index denotes a padding sample: it wraps around
    to ``_num_rows + idx`` and the returned loss mask is all zeros.

    Args:
        idx: Sample index.

    Returns:
        dict with keys:
            - input_ids: list[int] - Token IDs
            - seq_boundaries: list[int] - seq_start_id followed by len(input_ids)
            - loss_mask: list[int] - Per-token loss mask

    Raises:
        IndexError: If the row index falls outside ``[-_num_rows, _num_rows)``.
    """
    # Apply sample mapping if exists
    if self.samples_mapping is not None:
        idx = self.samples_mapping[idx]

    # Handle negative indices (padding samples) with wrap-around semantics.
    requested_idx = idx
    is_padding_sample = idx < 0
    if is_padding_sample:
        idx = self._num_rows + idx  # -1 -> last row, -N -> Nth from end

    # Reject anything that does not address a real row. Without this check an
    # index below -_num_rows stays negative and _locate_row() would resolve it
    # to file_idx == -1, silently addressing the wrong row.
    if not 0 <= idx < self._num_rows:
        raise IndexError(
            f"Row index {requested_idx} is out of range for packed Parquet dataset with {self._num_rows} rows"
        )

    # Locate the row across files and row groups
    file_idx, row_group_id, row_in_group = self._locate_row(idx)

    # Ensure reader is initialized for this file
    pf = self._ensure_reader(file_idx)

    # Read the row group only when it differs from the one cached last time.
    cache_key = (file_idx, row_group_id)
    if (self._cached_file_idx, self._cached_row_group_id) != cache_key:
        self._cached_row_group_table = pf.read_row_group(
            row_group_id, columns=["input_ids", "seq_start_id", "loss_mask"]
        )
        self._cached_file_idx = file_idx
        self._cached_row_group_id = row_group_id

    # Extract row values
    table = self._cached_row_group_table
    input_ids = table.column("input_ids")[row_in_group].as_py()
    seq_start_id = table.column("seq_start_id")[row_in_group].as_py()
    loss_mask = table.column("loss_mask")[row_in_group].as_py()

    # Compute derived field: the start offsets plus the end of the pack.
    seq_boundaries = seq_start_id + [len(input_ids)]

    # For padding samples, zero out the loss mask
    if is_padding_sample:
        loss_mask = [0] * len(loss_mask)

    return {
        "input_ids": input_ids,
        "seq_boundaries": seq_boundaries,
        "loss_mask": loss_mask,
    }
import json import logging +import warnings from dataclasses import dataclass from pathlib import Path import numpy as np from megatron.core.msc_utils import MultiStorageClientFeature +from megatron.bridge.data.datasets.packed_parquet import ( + is_packed_parquet_spec, + resolve_packed_parquet_paths, +) from megatron.bridge.data.datasets.packing_utils import create_hist, create_packing_strategy, fill_packing_strategy from megatron.bridge.data.datasets.sft import create_sft_dataset from megatron.bridge.training.tokenizers.tokenizer import MegatronTokenizer @@ -113,7 +118,9 @@ def pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id): data[key] = val return - ceil_to_nearest = lambda n, m: (n + m - 1) // m * m + def ceil_to_nearest(n, m): + return (n + m - 1) // m * m + for data in dataset: max_length_to_pad = min(max_seq_length, ceil_to_nearest(len(data["input_ids"]), pad_seq_length_to_mult)) pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id) @@ -169,11 +176,18 @@ def prepare_packed_sequence_data( output_data = fill_packing_strategy(assignments, sequences, packed_sequence_size, tokenizer.eos_id) # save output data - if MultiStorageClientFeature.is_enabled(): - msc = MultiStorageClientFeature.import_package() - msc.numpy.save(output_path, output_data) + output_path_str = str(output_path) + if output_path_str.lower().endswith((".parquet", ".pq")): + from megatron.bridge.data.datasets.packed_parquet import write_packed_parquet + + write_packed_parquet(output_data, output_path) else: - np.save(output_path, output_data) + # Legacy .npy format + if MultiStorageClientFeature.is_enabled(): + msc = MultiStorageClientFeature.import_package() + msc.numpy.save(output_path, output_data) + else: + np.save(output_path, output_data) # save packing metadata, packing_metadata is appended to the packing file if it exists if output_metadata_path is not None: @@ -247,30 +261,66 @@ class PackedSequenceSpecs: def __post_init__(self): if self.packed_train_data_path 
is not None: - if MultiStorageClientFeature.is_enabled(): - msc = MultiStorageClientFeature.import_package() - self.packed_train_data_path = msc.Path(self.packed_train_data_path) - else: - self.packed_train_data_path = Path(self.packed_train_data_path) - assert self.packed_train_data_path.suffix == ".npy", ( - f"packed training data file must be a .npy file: {self.packed_train_data_path}" - ) - assert self.packed_train_data_path.exists(), ( - f"packed training data file does not exist: {self.packed_train_data_path}" - ) + self._validate_packed_path("packed_train_data_path", self.packed_train_data_path) if self.packed_val_data_path is not None: + self._validate_packed_path("packed_val_data_path", self.packed_val_data_path) + + if self.pad_seq_to_mult is not None and self.pad_seq_to_mult <= 0: + raise ValueError("pad_seq_to_mult must be a positive integer when provided.") + + def _validate_packed_path(self, attr_name: str, path_value: str) -> None: + """Validate a packed data path and store it appropriately. + + For .npy files: strict validation with Path.exists() + For packed parquet specs: validate via resolution (supports dirs/globs) + + Args: + attr_name: The attribute name being validated (for error messages) + path_value: The path value to validate + + Raises: + FileNotFoundError: If the path does not exist or resolves to no files + ValueError: If the path format is invalid + """ + path_str = str(path_value) + + # Check if it's an .npy file (legacy format) + if path_str.lower().endswith(".npy"): + warnings.warn( + f"The .npy packed sequence format is deprecated and will be removed in the next release. " + f"Please use packed parquet format instead. 
Path: {path_str}", + DeprecationWarning, + stacklevel=2, + ) if MultiStorageClientFeature.is_enabled(): msc = MultiStorageClientFeature.import_package() - self.packed_val_data_path = msc.Path(self.packed_val_data_path) + path_obj = msc.Path(path_str) else: - self.packed_val_data_path = Path(self.packed_val_data_path) - assert self.packed_val_data_path.suffix == ".npy", ( - f"packed validation data file must be a .npy file: {self.packed_val_data_path}" - ) - assert self.packed_val_data_path.exists(), ( - f"packed validation data file does not exist: {self.packed_val_data_path}" - ) + path_obj = Path(path_str) - if self.pad_seq_to_mult is not None and self.pad_seq_to_mult <= 0: - raise ValueError("pad_seq_to_mult must be a positive integer when provided.") + if not path_obj.exists(): + raise FileNotFoundError(f"{attr_name} file does not exist: {path_str}") + setattr(self, attr_name, path_obj) + return + + # Check if it's a packed parquet spec (file/dir/glob) + if is_packed_parquet_spec(path_str): + # Validate by resolving - this checks that files actually exist + try: + resolved_paths = resolve_packed_parquet_paths(path_str) + if len(resolved_paths) == 0: + raise FileNotFoundError(f"{attr_name} resolved to no files: {path_str}") + except ValueError as e: + raise FileNotFoundError(f"{attr_name} could not be resolved: {path_str}. 
Error: {e}") from e + + # Store the original string spec (not Path) to preserve globs + # The dataset loader will handle resolution + setattr(self, attr_name, path_str) + return + + # Neither .npy nor valid packed parquet spec + raise ValueError( + f"{attr_name} must be a .npy file or a packed parquet spec " + f"(file/directory/glob ending in .parquet or .pq): {path_str}" + ) diff --git a/src/megatron/bridge/data/datasets/sft.py b/src/megatron/bridge/data/datasets/sft.py index f29900223b..eb498fdad1 100644 --- a/src/megatron/bridge/data/datasets/sft.py +++ b/src/megatron/bridge/data/datasets/sft.py @@ -78,7 +78,7 @@ def get_dataset_root(name: str) -> Path: def create_sft_dataset( - path: Path, + path: str | Path, tokenizer: "MegatronTokenizer", seq_length: int = 2048, add_bos: bool = False, @@ -96,7 +96,7 @@ def create_sft_dataset( hf_dataset: bool = False, global_sample_mapping: bool = False, get_attention_mask_from_fusion: bool = True, - pack_metadata_file_path: Path = None, + pack_metadata_file_path: Path | str | None = None, pad_cu_seqlens: bool = False, pad_seq_to_mult: int = 1, chat: bool = False, @@ -111,8 +111,19 @@ def create_sft_dataset( input parameters. It can create standard SFT datasets, chat-specific datasets, or packed sequence datasets. + Dataset selection logic: + 1. If path ends with .npy: GPTSFTPackedDataset (legacy packed format) + 2. If path is a packed parquet spec (file/dir/glob ending in .parquet/.pq, + or a directory): GPTSFTPackedParquetDataset + - Note: Selection is based on path pattern, not pack_metadata_file_path + - Schema validation (REQUIRED_COLUMNS) will fast-fail for non-packed files + 3. If chat=True: GPTSFTChatDataset + 4. Otherwise: GPTSFTDataset + Args: - path (Path): Path to the dataset file. For packed datasets, this should be a .npy file. + path (str | Path): Path to the dataset file or packed parquet spec (file/dir/glob). 
+ For packed datasets, this can be a .npy file, a .parquet file, a directory + containing parquet files, or a glob pattern. tokenizer (MegatronTokenizer): The tokenizer to use for tokenizing the data. seq_length (int, optional): Maximum sequence length for each example. Defaults to 2048. add_bos (bool, optional): Whether to add a beginning-of-sentence token. Defaults to False. @@ -135,8 +146,8 @@ def create_sft_dataset( or shuffle within each epoch. Defaults to False. get_attention_mask_from_fusion (bool): if true, lets attention kernel handle creation of causal mask instead of adding it to the batch dict. - pack_metadata_file_path (Path, optional): Path to the metadata file for packed datasets. - Required if `pad_cu_seqlens` is True. Defaults to None. + pack_metadata_file_path (Path | str | None, optional): Path to the metadata file for packed datasets. + When provided, enables packed mode. Required if `pad_cu_seqlens` is True. Defaults to None. pad_cu_seqlens (bool, optional): Whether to pad `cu_seqlens` for packed datasets, required for cudagraphs. Defaults to False. chat (bool, optional): If True, creates a `GPTSFTChatDataset`. Defaults to False. @@ -150,9 +161,11 @@ def create_sft_dataset( Returns: GPTSFTDataset | GPTSFTChatDataset | GPTSFTPackedDataset: An instance of the appropriate SFT dataset class. 
""" + # Normalize path to string for consistent handling + path_str = str(path) gpt_sft_dataset_kwargs = { - "file_path": str(path), + "file_path": path_str, "tokenizer": tokenizer, "max_seq_length": seq_length, "memmap_workers": memmap_workers, @@ -172,7 +185,8 @@ def create_sft_dataset( "get_attention_mask_from_fusion": get_attention_mask_from_fusion, } - if path.suffix == ".npy": + # Check for .npy packed dataset (legacy format) + if path_str.lower().endswith(".npy"): return GPTSFTPackedDataset( pack_metadata_file_path=pack_metadata_file_path, pad_cu_seqlens=pad_cu_seqlens, @@ -180,6 +194,26 @@ def create_sft_dataset( **gpt_sft_dataset_kwargs, **kwargs, ) + + # Lazy import to avoid circular dependency (packed_parquet imports from sft) + from megatron.bridge.data.datasets.packed_parquet import ( + GPTSFTPackedParquetDataset, + is_packed_parquet_spec, + ) + + # Select GPTSFTPackedParquetDataset for any packed parquet spec (file/dir/glob) + # This is determined purely by path pattern, NOT by pack_metadata_file_path. 
+ # Rationale: + # - Directory/glob specs clearly indicate packed parquet shards + # - Schema validation (REQUIRED_COLUMNS) will fast-fail if files aren't packed format + # - This allows externally-prepared packed data to work without requiring MB metadata + if is_packed_parquet_spec(path_str): + return GPTSFTPackedParquetDataset( + pack_metadata_file_path=pack_metadata_file_path, + pad_cu_seqlens=pad_cu_seqlens, + **gpt_sft_dataset_kwargs, + **kwargs, + ) elif chat: return GPTSFTChatDataset( **gpt_sft_dataset_kwargs, diff --git a/tests/unit_tests/data/datasets/test_chat_template.py b/tests/unit_tests/data/datasets/test_chat_template.py index 4476cf0476..6c5a190457 100644 --- a/tests/unit_tests/data/datasets/test_chat_template.py +++ b/tests/unit_tests/data/datasets/test_chat_template.py @@ -449,6 +449,138 @@ def test_create_packed_dataset_priority(self, mock_packed_class): # Verify GPTSFTPackedDataset was called (not GPTSFTChatDataset) mock_packed_class.assert_called_once() + @patch("megatron.bridge.data.datasets.sft.GPTSFTPackedParquetDataset") + def test_create_packed_parquet_dataset_idx_parquet(self, mock_parquet_class): + """Test that .idx.parquet files create GPTSFTPackedParquetDataset.""" + from pathlib import Path + + mock_tokenizer = MagicMock() + mock_parquet_class.return_value = MagicMock() + + create_sft_dataset( + path=Path("test.idx.parquet"), + tokenizer=mock_tokenizer, + ) + + # Verify GPTSFTPackedParquetDataset was called + mock_parquet_class.assert_called_once() + + @patch("megatron.bridge.data.datasets.sft.GPTSFTPackedParquetDataset") + def test_create_packed_parquet_dataset_idx_pq(self, mock_parquet_class): + """Test that .idx.pq files create GPTSFTPackedParquetDataset.""" + from pathlib import Path + + mock_tokenizer = MagicMock() + mock_parquet_class.return_value = MagicMock() + + create_sft_dataset( + path=Path("test.idx.pq"), + tokenizer=mock_tokenizer, + ) + + # Verify GPTSFTPackedParquetDataset was called + 
mock_parquet_class.assert_called_once() + + @patch("megatron.bridge.data.datasets.sft.GPTSFTPackedParquetDataset") + def test_create_packed_parquet_dataset_priority_over_chat(self, mock_parquet_class): + """Test that packed Parquet files take precedence over chat=True.""" + from pathlib import Path + + mock_tokenizer = MagicMock() + mock_parquet_class.return_value = MagicMock() + + create_sft_dataset( + path=Path("test.idx.parquet"), + tokenizer=mock_tokenizer, + chat=True, # Should be ignored for packed Parquet files + use_hf_tokenizer_chat_template=True, + ) + + # Verify GPTSFTPackedParquetDataset was called (not GPTSFTChatDataset) + mock_parquet_class.assert_called_once() + + @patch("megatron.bridge.data.datasets.sft.GPTSFTChatDataset") + def test_regular_parquet_not_routed_to_packed(self, mock_chat_class): + """Test that regular .parquet files (without .idx.) are NOT routed to packed dataset.""" + from pathlib import Path + + mock_tokenizer = MagicMock() + mock_chat_class.return_value = MagicMock() + + create_sft_dataset( + path=Path("test.parquet"), # No .idx. 
prefix + tokenizer=mock_tokenizer, + chat=True, + use_hf_tokenizer_chat_template=True, + ) + + # Verify GPTSFTChatDataset was called (regular parquet goes to chat/default) + mock_chat_class.assert_called_once() + + @patch("megatron.bridge.data.datasets.sft.GPTSFTPackedParquetDataset") + def test_create_packed_parquet_glob_pattern(self, mock_parquet_class): + """Test that glob patterns like data*.idx.parquet route to GPTSFTPackedParquetDataset.""" + from pathlib import Path + + mock_tokenizer = MagicMock() + mock_parquet_class.return_value = MagicMock() + + create_sft_dataset( + path=Path("data/shard_*.idx.parquet"), # Glob pattern + tokenizer=mock_tokenizer, + ) + + # Verify GPTSFTPackedParquetDataset was called + mock_parquet_class.assert_called_once() + + +class TestIsPackedParquetFile: + """Test cases for is_packed_parquet_file utility function.""" + + def test_single_file_idx_parquet(self): + """Test detection of single .idx.parquet file.""" + from megatron.bridge.data.datasets.packed_parquet import is_packed_parquet_file + + assert is_packed_parquet_file("data.idx.parquet") is True + assert is_packed_parquet_file("/path/to/data.idx.parquet") is True + + def test_single_file_idx_pq(self): + """Test detection of single .idx.pq file.""" + from megatron.bridge.data.datasets.packed_parquet import is_packed_parquet_file + + assert is_packed_parquet_file("data.idx.pq") is True + assert is_packed_parquet_file("/path/to/data.idx.pq") is True + + def test_glob_pattern_idx_parquet(self): + """Test detection of glob patterns for .idx.parquet.""" + from megatron.bridge.data.datasets.packed_parquet import is_packed_parquet_file + + assert is_packed_parquet_file("data*.idx.parquet") is True + assert is_packed_parquet_file("shard_?.idx.parquet") is True + assert is_packed_parquet_file("/path/to/data*.idx.parquet") is True + + def test_glob_pattern_idx_pq(self): + """Test detection of glob patterns for .idx.pq.""" + from megatron.bridge.data.datasets.packed_parquet import 
is_packed_parquet_file + + assert is_packed_parquet_file("data*.idx.pq") is True + assert is_packed_parquet_file("shard_?.idx.pq") is True + + def test_regular_parquet_not_detected(self): + """Test that regular .parquet files are not detected as packed.""" + from megatron.bridge.data.datasets.packed_parquet import is_packed_parquet_file + + assert is_packed_parquet_file("data.parquet") is False + assert is_packed_parquet_file("data*.parquet") is False + assert is_packed_parquet_file("/path/to/data.parquet") is False + + def test_case_insensitive(self): + """Test case-insensitive detection.""" + from megatron.bridge.data.datasets.packed_parquet import is_packed_parquet_file + + assert is_packed_parquet_file("DATA.IDX.PARQUET") is True + assert is_packed_parquet_file("Data.Idx.Pq") is True + class TestPackedDatasetNaNFix: """Test cases for NaN fix in packed dataset collate_fn.""" diff --git a/tests/unit_tests/data/datasets/test_packed_parquet.py b/tests/unit_tests/data/datasets/test_packed_parquet.py new file mode 100644 index 0000000000..36b5acc1ab --- /dev/null +++ b/tests/unit_tests/data/datasets/test_packed_parquet.py @@ -0,0 +1,383 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for GPTSFTPackedParquetDataset. 
+ +These tests create real Parquet files and exercise the dataset end-to-end, +covering _locate_row, row-group caching, multi-file support, schema validation, +and the validate_row helper. +""" + +from __future__ import annotations + +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + + +# Optional dependency — skip the entire module when pyarrow is missing. +pa = pytest.importorskip("pyarrow") +pq = pytest.importorskip("pyarrow.parquet") + +from megatron.bridge.data.datasets.packed_parquet import ( + GPTSFTPackedParquetDataset, + _resolve_parquet_paths, + is_packed_parquet_file, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_tokenizer_mock(): + """Return a minimal tokenizer mock sufficient for GPTSFTPackedDataset.""" + tok = MagicMock() + tok.eos_id = 0 + tok.eod = 0 + tok.pad_id = 0 + return tok + + +def _make_packed_row(n_tokens: int = 64, n_seqs: int = 2): + """Build a single packed-row dict. + + Returns input_ids, loss_mask, and seq_start_id with valid invariants. 
+ """ + assert n_seqs >= 1 + tokens_per_seq = n_tokens // n_seqs + seq_start_id = [i * tokens_per_seq for i in range(n_seqs)] + return { + "input_ids": list(range(1, n_tokens + 1)), + "loss_mask": [1] * n_tokens, + "seq_start_id": seq_start_id, + } + + +def _write_parquet(path: str | Path, rows: list[dict], row_group_size: int = 500): + """Write a list of packed-row dicts to a Parquet file.""" + table = pa.table( + { + "input_ids": [row["input_ids"] for row in rows], + "loss_mask": [row["loss_mask"] for row in rows], + "seq_start_id": [row["seq_start_id"] for row in rows], + } + ) + pq.write_table(table, str(path), row_group_size=row_group_size) + + +# --------------------------------------------------------------------------- +# Tests: is_packed_parquet_file +# --------------------------------------------------------------------------- + + +class TestIsPackedParquetFile: + def test_direct_idx_parquet(self): + assert is_packed_parquet_file("data.idx.parquet") is True + + def test_direct_idx_pq(self): + assert is_packed_parquet_file("data.idx.pq") is True + + def test_glob_pattern(self): + assert is_packed_parquet_file("shard_*.idx.parquet") is True + + def test_regular_parquet_rejected(self): + assert is_packed_parquet_file("data.parquet") is False + + def test_case_insensitive(self): + assert is_packed_parquet_file("DATA.IDX.PARQUET") is True + + +# --------------------------------------------------------------------------- +# Tests: _resolve_parquet_paths +# --------------------------------------------------------------------------- + + +class TestResolveParquetPaths: + def test_single_file(self, tmp_path): + f = tmp_path / "data.idx.parquet" + _write_parquet(f, [_make_packed_row()]) + paths = _resolve_parquet_paths(str(f)) + assert paths == [str(f)] + + def test_glob_pattern(self, tmp_path): + for i in range(3): + f = tmp_path / f"shard_{i:03d}.idx.parquet" + _write_parquet(f, [_make_packed_row()]) + paths = _resolve_parquet_paths(str(tmp_path / 
"shard_*.idx.parquet")) + assert len(paths) == 3 + assert paths == sorted(paths) + + def test_directory(self, tmp_path): + for i in range(2): + _write_parquet(tmp_path / f"s{i}.idx.parquet", [_make_packed_row()]) + paths = _resolve_parquet_paths(str(tmp_path)) + assert len(paths) == 2 + + def test_missing_file_raises(self, tmp_path): + with pytest.raises(ValueError, match="not found"): + _resolve_parquet_paths(str(tmp_path / "nonexistent.idx.parquet")) + + def test_empty_dir_raises(self, tmp_path): + with pytest.raises(ValueError, match="No packed Parquet files found"): + _resolve_parquet_paths(str(tmp_path)) + + +# --------------------------------------------------------------------------- +# Tests: GPTSFTPackedParquetDataset +# --------------------------------------------------------------------------- + + +def _make_dataset(file_path, max_num_samples=None, **kwargs): + """Construct a GPTSFTPackedParquetDataset with minimal config.""" + return GPTSFTPackedParquetDataset( + file_path=str(file_path), + tokenizer=_make_tokenizer_mock(), + max_seq_length=4096, + max_num_samples=max_num_samples, + pad_to_max_length=False, + add_bos=False, + add_eos=False, + add_sep=False, + seed=42, + answer_only_loss=True, + truncation_field="input", + prompt_template="{input} {output}", + return_cu_seqlen=True, + **kwargs, + ) + + +class TestPackedParquetDatasetSingleFile: + """Tests using a single Parquet file.""" + + @pytest.fixture() + def parquet_file(self, tmp_path): + rows = [_make_packed_row(n_tokens=64, n_seqs=2) for _ in range(10)] + path = tmp_path / "data.idx.parquet" + _write_parquet(path, rows, row_group_size=5) + return path + + def test_len(self, parquet_file): + ds = _make_dataset(parquet_file) + assert len(ds) == 10 + + def test_getitem_returns_expected_keys(self, parquet_file): + ds = _make_dataset(parquet_file) + sample = ds[0] + assert "input_ids" in sample + assert "loss_mask" in sample + assert "seq_boundaries" in sample + + def 
test_getitem_seq_boundaries(self, parquet_file): + ds = _make_dataset(parquet_file) + sample = ds[0] + # seq_start_id was [0, 32], so boundaries should be [0, 32, 64] + assert sample["seq_boundaries"] == [0, 32, 64] + + def test_getitem_all_rows(self, parquet_file): + ds = _make_dataset(parquet_file) + for i in range(len(ds)): + sample = ds[i] + assert len(sample["input_ids"]) == 64 + assert len(sample["loss_mask"]) == 64 + + def test_negative_index_zeroes_loss_mask(self, parquet_file): + ds = _make_dataset(parquet_file) + sample = ds[-1] + assert all(m == 0 for m in sample["loss_mask"]) + + def test_max_num_samples(self, parquet_file): + ds = _make_dataset(parquet_file, max_num_samples=5) + assert len(ds) == 5 + + def test_oversampling(self, parquet_file): + ds = _make_dataset(parquet_file, max_num_samples=25) + assert len(ds) == 25 + # Should be able to iterate all + for i in range(len(ds)): + ds[i] + + +class TestPackedParquetDatasetMultiFile: + """Tests using multiple Parquet files.""" + + @pytest.fixture() + def multi_parquet(self, tmp_path): + # 3 shards with 5, 10, 7 rows respectively + counts = [5, 10, 7] + for i, n in enumerate(counts): + rows = [_make_packed_row(n_tokens=32, n_seqs=1) for _ in range(n)] + _write_parquet(tmp_path / f"shard_{i:03d}.idx.parquet", rows, row_group_size=4) + return tmp_path + + def test_total_len(self, multi_parquet): + ds = _make_dataset(str(multi_parquet / "shard_*.idx.parquet")) + assert len(ds) == 22 + + def test_getitem_across_files(self, multi_parquet): + ds = _make_dataset(str(multi_parquet / "shard_*.idx.parquet")) + for i in range(len(ds)): + sample = ds[i] + assert len(sample["input_ids"]) == 32 + + def test_locate_row_boundaries(self, multi_parquet): + ds = _make_dataset(str(multi_parquet / "shard_*.idx.parquet")) + # Row 0 -> file 0, row 4 -> file 0, row 5 -> file 1, row 15 -> file 2 + file_idx_0, _, _ = ds._locate_row(0) + file_idx_4, _, _ = ds._locate_row(4) + file_idx_5, _, _ = ds._locate_row(5) + file_idx_15, _, 
_ = ds._locate_row(15) + assert file_idx_0 == 0 + assert file_idx_4 == 0 + assert file_idx_5 == 1 + assert file_idx_15 == 2 + + +class TestPackedParquetDatasetRowGroupCache: + """Tests for row-group caching behavior.""" + + @pytest.fixture() + def parquet_file(self, tmp_path): + rows = [_make_packed_row(n_tokens=16, n_seqs=1) for _ in range(20)] + path = tmp_path / "data.idx.parquet" + _write_parquet(path, rows, row_group_size=5) + return path + + def test_cache_reuse_within_row_group(self, parquet_file): + ds = _make_dataset(parquet_file) + ds[0] # Load row group 0 + cached_table = ds._cached_row_group_table + ds[1] # Same row group + assert ds._cached_row_group_table is cached_table + + def test_cache_eviction_on_new_row_group(self, parquet_file): + ds = _make_dataset(parquet_file) + ds[0] # Row group 0 + old_table = ds._cached_row_group_table + ds[5] # Row group 1 + assert ds._cached_row_group_table is not old_table + + +class TestPackedParquetSchemaValidation: + """Tests for schema validation during _load_dataset.""" + + def test_missing_column_raises(self, tmp_path): + # Write a parquet file missing 'loss_mask' + table = pa.table( + { + "input_ids": [[1, 2, 3]], + "seq_start_id": [[0]], + } + ) + path = tmp_path / "bad.idx.parquet" + pq.write_table(table, str(path)) + + with pytest.raises(ValueError, match="missing required columns"): + _make_dataset(path) + + def test_empty_file_raises(self, tmp_path): + table = pa.table( + { + "input_ids": pa.array([], type=pa.list_(pa.int32())), + "loss_mask": pa.array([], type=pa.list_(pa.int8())), + "seq_start_id": pa.array([], type=pa.list_(pa.int32())), + } + ) + path = tmp_path / "empty.idx.parquet" + pq.write_table(table, str(path)) + + with pytest.raises(ValueError, match="empty"): + _make_dataset(path) + + +class TestValidateRow: + """Tests for the static validate_row method.""" + + def test_valid_row(self): + GPTSFTPackedParquetDataset.validate_row( + idx=0, + input_ids=[1, 2, 3, 4], + loss_mask=[1, 1, 0, 1], + 
seq_start_id=[0, 2], + ) + + def test_loss_mask_length_mismatch(self): + with pytest.raises(ValueError, match="loss_mask length"): + GPTSFTPackedParquetDataset.validate_row( + idx=0, + input_ids=[1, 2, 3], + loss_mask=[1, 1], + seq_start_id=[0], + ) + + def test_seq_start_id_not_starting_with_zero(self): + with pytest.raises(ValueError, match="must start with 0"): + GPTSFTPackedParquetDataset.validate_row( + idx=0, + input_ids=[1, 2, 3], + loss_mask=[1, 1, 1], + seq_start_id=[1], + ) + + def test_seq_start_id_empty(self): + with pytest.raises(ValueError, match="must start with 0"): + GPTSFTPackedParquetDataset.validate_row( + idx=0, + input_ids=[1, 2, 3], + loss_mask=[1, 1, 1], + seq_start_id=[], + ) + + def test_seq_start_id_out_of_bounds(self): + with pytest.raises(ValueError, match=">="): + GPTSFTPackedParquetDataset.validate_row( + idx=0, + input_ids=[1, 2, 3], + loss_mask=[1, 1, 1], + seq_start_id=[0, 5], + ) + + def test_seq_start_id_not_non_decreasing(self): + with pytest.raises(ValueError, match="not non-decreasing"): + GPTSFTPackedParquetDataset.validate_row( + idx=0, + input_ids=[1, 2, 3, 4], + loss_mask=[1, 1, 1, 1], + seq_start_id=[0, 3, 1], + ) + + +class TestPackedParquetClose: + """Tests for resource cleanup.""" + + def test_close_clears_state(self, tmp_path): + path = tmp_path / "data.idx.parquet" + _write_parquet(path, [_make_packed_row()]) + ds = _make_dataset(path) + ds[0] # Force reader open + assert len(ds._parquet_files) > 0 + + ds.close() + assert len(ds._parquet_files) == 0 + assert ds._cached_row_group_table is None + + def test_double_close_safe(self, tmp_path): + path = tmp_path / "data.idx.parquet" + _write_parquet(path, [_make_packed_row()]) + ds = _make_dataset(path) + ds.close() + ds.close() # Should not raise