173 changes: 160 additions & 13 deletions src/datasets/packaged_modules/parquet/parquet.py
@@ -13,6 +13,56 @@
logger = datasets.utils.logging.get_logger(__name__)


def _is_nested_type(pa_type):
"""Check if a PyArrow type contains nested structures."""
return (
pa.types.is_list(pa_type)
or pa.types.is_large_list(pa_type)
or pa.types.is_struct(pa_type)
or pa.types.is_map(pa_type)
or pa.types.is_union(pa_type)
)


def _handle_nested_chunked_conversion(pa_table):
"""Handle PyArrow nested data conversion issues by combining chunks selectively."""
try:
# Check if any columns have multiple chunks with nested data
needs_combining = False
for column_name in pa_table.column_names:
column = pa_table.column(column_name)
if isinstance(column, pa.ChunkedArray) and column.num_chunks > 1:
# Check if column contains nested types
if _is_nested_type(column.type):
needs_combining = True
break

if needs_combining:
# Combine chunks only for problematic columns to minimize memory impact
combined_columns = {}
for column_name in pa_table.column_names:
column = pa_table.column(column_name)
if (
isinstance(column, pa.ChunkedArray)
and column.num_chunks > 1
and _is_nested_type(column.type)
):
combined_columns[column_name] = column.combine_chunks()
else:
combined_columns[column_name] = column

return pa.table(combined_columns)

return pa_table

except Exception as e:
# Fallback: combine all chunks if selective approach fails
logger.warning(
f"Selective chunk combining failed, using full combine_chunks(): {e}"
)
return pa_table.combine_chunks()

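# Illustrative sketch (assumed behavior, not part of the original module): run in an
# interactive session to see the selective combining at work. Only the nested
# list<int64> column is collapsed to a single chunk; the flat int64 column keeps its
# original chunking and pays no copy cost.
#
#   import pyarrow as pa
#   nested = pa.chunked_array([pa.array([[1, 2]]), pa.array([[3]])])  # 2 chunks, list<int64>
#   flat = pa.chunked_array([pa.array([10]), pa.array([20])])         # 2 chunks, int64
#   table = pa.table({"nested": nested, "flat": flat})
#   combined = _handle_nested_chunked_conversion(table)
#   assert combined["nested"].num_chunks == 1
#   assert combined["flat"].num_chunks == 2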

@dataclass
class ParquetConfig(datasets.BuilderConfig):
"""BuilderConfig for Parquet."""
@@ -44,7 +94,9 @@ def _info(self):
def _split_generators(self, dl_manager):
"""We handle string, list and dicts in datafiles"""
if not self.config.data_files:
raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
raise ValueError(
f"At least one data file must be specified, but got data_files={self.config.data_files}"
)
dl_manager.download_config.extract_on_the_fly = True
data_files = dl_manager.download_and_extract(self.config.data_files)
splits = []
@@ -57,12 +109,22 @@ def _split_generators(self, dl_manager):
if self.info.features is None:
for file in itertools.chain.from_iterable(files):
with open(file, "rb") as f:
self.info.features = datasets.Features.from_arrow_schema(pq.read_schema(f))
self.info.features = datasets.Features.from_arrow_schema(
pq.read_schema(f)
)
break
splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files}))
if self.config.columns is not None and set(self.config.columns) != set(self.info.features):
splits.append(
datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})
)
if self.config.columns is not None and set(self.config.columns) != set(
self.info.features
):
self.info.features = datasets.Features(
{col: feat for col, feat in self.info.features.items() if col in self.config.columns}
{
col: feat
for col, feat in self.info.features.items()
if col in self.config.columns
}
)
return splits

@@ -75,7 +137,9 @@ def _cast_table(self, pa_table: pa.Table) -> pa.Table:

def _generate_tables(self, files):
if self.config.features is not None and self.config.columns is not None:
if sorted(field.name for field in self.info.features.arrow_schema) != sorted(self.config.columns):
if sorted(field.name for field in self.info.features.arrow_schema) != sorted(
self.config.columns
):
raise ValueError(
f"Tried to load parquet data with columns '{self.config.columns}' with mismatching features '{self.info.features}'"
)
@@ -88,7 +152,9 @@ def _generate_tables(self, files):
with open(file, "rb") as f:
parquet_fragment = ds.ParquetFileFormat().make_fragment(f)
if parquet_fragment.row_groups:
batch_size = self.config.batch_size or parquet_fragment.row_groups[0].num_rows
batch_size = (
self.config.batch_size or parquet_fragment.row_groups[0].num_rows
)
try:
for batch_idx, record_batch in enumerate(
parquet_fragment.to_batches(
@@ -100,10 +166,91 @@
)
):
pa_table = pa.Table.from_batches([record_batch])
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
# logger.warning('\n'.join(str(pa_table.slice(i, 1).to_pydict()) for i in range(pa_table.num_rows)))
yield f"{file_idx}_{batch_idx}", self._cast_table(pa_table)
except ValueError as e:
logger.error(f"Failed to read file '{file}' with error {type(e)}: {e}")
raise
except pa.ArrowNotImplementedError as e:
if (
"Nested data conversions not implemented for chunked array outputs"
in str(e)
):
# Fallback for nested data: bypass fragment reading entirely
logger.warning(
f"Using fallback for nested data in file '{file}': {e}"
)
try:
# Reset file pointer and use direct parquet file reading
f.seek(0)
parquet_file = pq.ParquetFile(f)

# Read row groups one by one to avoid chunking issues
tables = []
for row_group_idx in range(
parquet_file.num_row_groups
):
try:
# Read single row group
rg_table = parquet_file.read_row_group(
row_group_idx,
columns=self.config.columns,
use_pandas_metadata=False,
)

# Apply filter if needed
if filter_expr is not None:
rg_table = rg_table.filter(filter_expr)

# Immediately combine chunks
if rg_table.num_rows > 0:
rg_table = (
_handle_nested_chunked_conversion(
rg_table
)
)
tables.append(rg_table)

except pa.ArrowNotImplementedError as rg_error:
if (
"Nested data conversions not implemented"
in str(rg_error)
):
logger.warning(
f"Skipping row group {row_group_idx} due to nested data issues: {rg_error}"
)
continue
else:
raise

if not tables:
logger.error(
f"Could not read any row groups from file '{file}'"
)
continue

# Combine all readable row groups
full_table = (
pa.concat_tables(tables)
if len(tables) > 1
else tables[0]
)

# Split into batches manually
for batch_idx in range(
0, full_table.num_rows, batch_size
):
end_idx = min(
batch_idx + batch_size, full_table.num_rows
)
batch_table = full_table.slice(
batch_idx, end_idx - batch_idx
)
yield f"{file_idx}_{batch_idx // batch_size}", (
self._cast_table(batch_table)
)

except Exception as fallback_error:
logger.error(
f"Fallback approach also failed for file '{file}': {fallback_error}"
)
raise
else:
# Re-raise if it's a different Arrow error
raise
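
For reference, a minimal standalone sketch of the fallback pattern used above, under the assumption that a file triggers PyArrow's "Nested data conversions not implemented for chunked array outputs" error: the Parquet file is reopened with pyarrow.parquet.ParquetFile and read one row group at a time, combining chunks before the table is used downstream. The path and column list below are placeholders.

import pyarrow.parquet as pq

def iter_row_group_tables(path, columns=None):
    """Yield one chunk-combined pyarrow Table per row group (illustrative sketch)."""
    parquet_file = pq.ParquetFile(path)
    for row_group_idx in range(parquet_file.num_row_groups):
        table = parquet_file.read_row_group(row_group_idx, columns=columns)
        # Combining chunks per row group avoids the chunked-array conversion error
        # when nested columns are converted or cast later.
        yield table.combine_chunks()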