[FEAT]: support optional rowgroups to read_parquet (#2534)
closes #2500
universalmind303 authored Jul 19, 2024
1 parent 63518de commit 9cd1482
Showing 7 changed files with 102 additions and 25 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
@@ -190,11 +190,13 @@ class ParquetSourceConfig:

coerce_int96_timestamp_unit: PyTimeUnit | None
field_id_mapping: dict[int, PyField] | None
row_groups: list[list[int]] | None

def __init__(
self,
coerce_int96_timestamp_unit: PyTimeUnit | None = None,
field_id_mapping: dict[int, PyField] | None = None,
row_groups: list[list[int]] | None = None,
): ...

class CsvSourceConfig:
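A minimal construction sketch for the updated stub (an assumption about internal usage, mirroring what read_parquet does in the next file rather than a documented public entry point):

    from daft.daft import FileFormatConfig, ParquetSourceConfig

    # row_groups is optional; when given, it is one list of row-group indices per file.
    config = FileFormatConfig.from_parquet_config(
        ParquetSourceConfig(row_groups=[[0, 1], [2]])
    )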
10 changes: 9 additions & 1 deletion daft/io/_parquet.py
@@ -20,6 +20,7 @@
@PublicAPI
def read_parquet(
path: Union[str, List[str]],
row_groups: Optional[List[List[int]]] = None,
schema_hints: Optional[Dict[str, DataType]] = None,
io_config: Optional["IOConfig"] = None,
use_native_downloader: bool = True,
@@ -33,9 +34,11 @@ def read_parquet(
>>> df = daft.read_parquet("/path/to/directory")
>>> df = daft.read_parquet("/path/to/files-*.parquet")
>>> df = daft.read_parquet("s3://path/to/files-*.parquet")
>>> df = daft.read_parquet("gs://path/to/files-*.parquet")
Args:
path (str): Path to Parquet file (allows for wildcards)
row_groups (List[List[int]]): Row-group indices to read, with one list of indices per file in path.
schema_hints (dict[str, DataType]): A mapping between column names and datatypes - passing this option
will override the specified columns on the inferred schema with the specified DataTypes
io_config (IOConfig): Config to be used with the native downloader
@@ -63,8 +66,13 @@ def read_parquet(

pytimeunit = coerce_int96_timestamp_unit._timeunit if coerce_int96_timestamp_unit is not None else None

if isinstance(path, list) and row_groups is not None and len(path) != len(row_groups):
raise ValueError("row_groups must be the same length as the list of paths provided.")
if isinstance(row_groups, list) and not isinstance(path, list):
raise ValueError("row_groups are only supported when reading multiple non-globbed/wildcarded files")

file_format_config = FileFormatConfig.from_parquet_config(
ParquetSourceConfig(coerce_int96_timestamp_unit=pytimeunit)
ParquetSourceConfig(coerce_int96_timestamp_unit=pytimeunit, row_groups=row_groups)
)
if use_native_downloader:
storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config))
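As a usage sketch of the new parameter (the file paths below are hypothetical; row-group indices are zero-based, and each inner list applies to the file at the same position in path):

    import daft

    # Pass an explicit list of paths (not a glob or a single string) when using row_groups.
    paths = ["data/part-0.parquet", "data/part-1.parquet"]

    # Read row group 0 from the first file and row groups 0 and 2 from the second.
    df = daft.read_parquet(paths, row_groups=[[0], [0, 2]])

    # Per the checks added above, a non-list path or a row_groups list whose length
    # differs from len(paths) raises ValueError.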
6 changes: 6 additions & 0 deletions src/daft-micropartition/src/micropartition.rs
@@ -130,10 +130,12 @@ fn materialize_scan_task(
FileFormatConfig::Parquet(ParquetSourceConfig {
coerce_int96_timestamp_unit,
field_id_mapping,
..
}) => {
let inference_options =
ParquetSchemaInferenceOptions::new(Some(*coerce_int96_timestamp_unit));
let urls = urls.collect::<Vec<_>>();

let row_groups = parquet_sources_to_row_groups(scan_task.sources.as_slice());
let metadatas = scan_task
.sources
@@ -569,6 +571,7 @@ impl MicroPartition {
FileFormatConfig::Parquet(ParquetSourceConfig {
coerce_int96_timestamp_unit,
field_id_mapping,
..
}),
StorageConfig::Native(cfg),
) => {
@@ -589,6 +592,7 @@
.collect::<Option<Vec<_>>>();

let row_groups = parquet_sources_to_row_groups(scan_task.sources.as_slice());

read_parquet_into_micropartition(
uris.as_slice(),
columns.as_deref(),
@@ -1135,6 +1139,7 @@ pub(crate) fn read_parquet_into_micropartition(
.zip(metadata)
.zip(
row_groups
.clone()
.unwrap_or_else(|| std::iter::repeat(None).take(uris.len()).collect()),
)
.map(|((url, metadata), rgs)| DataFileSource::AnonymousDataFile {
@@ -1150,6 +1155,7 @@
FileFormatConfig::Parquet(ParquetSourceConfig {
coerce_int96_timestamp_unit: schema_infer_options.coerce_int96_timestamp_unit,
field_id_mapping,
row_groups,
})
.into(),
scan_task_daft_schema,
54 changes: 34 additions & 20 deletions src/daft-scan/src/anonymous.rs
@@ -4,8 +4,9 @@ use common_error::DaftResult;
use daft_core::schema::SchemaRef;

use crate::{
file_format::FileFormatConfig, storage_config::StorageConfig, DataFileSource, PartitionField,
Pushdowns, ScanOperator, ScanTask, ScanTaskRef,
file_format::{FileFormatConfig, ParquetSourceConfig},
storage_config::StorageConfig,
ChunkSpec, DataFileSource, PartitionField, Pushdowns, ScanOperator, ScanTask, ScanTaskRef,
};
#[derive(Debug)]
pub struct AnonymousScanOperator {
@@ -70,24 +71,37 @@ impl ScanOperator for AnonymousScanOperator {
let schema = self.schema.clone();
let storage_config = self.storage_config.clone();

let row_groups = if let FileFormatConfig::Parquet(ParquetSourceConfig {
row_groups: Some(row_groups),
..
}) = self.file_format_config.as_ref()
{
row_groups.clone()
} else {
std::iter::repeat(None).take(files.len()).collect()
};

// Create one ScanTask per file.
Ok(Box::new(files.into_iter().map(move |f| {
Ok(ScanTask::new(
vec![DataFileSource::AnonymousDataFile {
path: f.to_string(),
chunk_spec: None,
size_bytes: None,
metadata: None,
partition_spec: None,
statistics: None,
parquet_metadata: None,
}],
file_format_config.clone(),
schema.clone(),
storage_config.clone(),
pushdowns.clone(),
)
.into())
})))
Ok(Box::new(files.into_iter().zip(row_groups).map(
move |(f, rg)| {
let chunk_spec = rg.map(ChunkSpec::Parquet);
Ok(ScanTask::new(
vec![DataFileSource::AnonymousDataFile {
path: f.to_string(),
chunk_spec,
size_bytes: None,
metadata: None,
partition_spec: None,
statistics: None,
parquet_metadata: None,
}],
file_format_config.clone(),
schema.clone(),
storage_config.clone(),
pushdowns.clone(),
)
.into())
},
)))
}
}
22 changes: 22 additions & 0 deletions src/daft-scan/src/file_format.rs
@@ -121,6 +121,7 @@ pub struct ParquetSourceConfig {
///
/// See: https://github.com/apache/parquet-format/blob/master/src/main/thrift/parquet.thrift#L456-L459
pub field_id_mapping: Option<Arc<BTreeMap<i32, Field>>>,
pub row_groups: Option<Vec<Option<Vec<i64>>>>,
}

impl ParquetSourceConfig {
@@ -140,6 +141,25 @@
.join(",")
));
}
if let Some(row_groups) = &self.row_groups {
res.push(format!(
"Row Groups = {{{}}}",
row_groups
.iter()
.map(|rg| {
rg.as_ref()
.map(|rg| {
rg.iter()
.map(|i| i.to_string())
.collect::<Vec<String>>()
.join(",")
})
.unwrap_or_else(|| "None".to_string())
})
.collect::<Vec<String>>()
.join(",")
));
}
res
}
}
@@ -152,6 +172,7 @@ impl ParquetSourceConfig {
fn new(
coerce_int96_timestamp_unit: Option<PyTimeUnit>,
field_id_mapping: Option<BTreeMap<i32, PyField>>,
row_groups: Option<Vec<Option<Vec<i64>>>>,
) -> Self {
Self {
coerce_int96_timestamp_unit: coerce_int96_timestamp_unit
Expand All @@ -162,6 +183,7 @@ impl ParquetSourceConfig {
map.into_iter().map(|(k, v)| (k, v.field)),
))
}),
row_groups,
}
}

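As a rough illustration of the new display fragment (inferred from the multiline_display code above, not captured from real output), a config whose row_groups is Some(vec![Some(vec![0, 1]), None]) would contribute the line:

    Row Groups = {0,1,None}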
24 changes: 20 additions & 4 deletions src/daft-scan/src/glob.rs
@@ -11,7 +11,7 @@ use snafu::Snafu;
use crate::{
file_format::{CsvSourceConfig, FileFormatConfig, ParquetSourceConfig},
storage_config::StorageConfig,
DataFileSource, PartitionField, Pushdowns, ScanOperator, ScanTask, ScanTaskRef,
ChunkSpec, DataFileSource, PartitionField, Pushdowns, ScanOperator, ScanTask, ScanTaskRef,
};
#[derive(Debug)]
pub struct GlobScanOperator {
@@ -167,6 +167,7 @@ impl GlobScanOperator {
FileFormatConfig::Parquet(ParquetSourceConfig {
coerce_int96_timestamp_unit,
field_id_mapping,
..
}) => {
let io_stats = IOStatsContext::new(format!(
"GlobScanOperator constructor read_parquet_schema: for uri {first_filepath}"
@@ -299,8 +300,19 @@ impl ScanOperator for GlobScanOperator {
let schema = self.schema.clone();
let storage_config = self.storage_config.clone();
let is_ray_runner = self.is_ray_runner;

let row_groups = if let FileFormatConfig::Parquet(ParquetSourceConfig {
row_groups: Some(row_groups),
..
}) = self.file_format_config.as_ref()
{
Some(row_groups.clone())
} else {
None
};

// Create one ScanTask per file
Ok(Box::new(files.map(move |f| {
Ok(Box::new(files.enumerate().map(move |(idx, f)| {
let FileMetadata {
filepath: path,
size: size_bytes,
@@ -337,11 +349,15 @@
} else {
None
};

let row_group = row_groups
.as_ref()
.and_then(|rgs| rgs.get(idx).cloned())
.flatten();
let chunk_spec = row_group.map(ChunkSpec::Parquet);
Ok(ScanTask::new(
vec![DataFileSource::AnonymousDataFile {
path: path.to_string(),
chunk_spec: None,
chunk_spec,
size_bytes,
metadata: None,
partition_spec: None,
9 changes: 9 additions & 0 deletions tests/io/test_parquet.py
@@ -138,3 +138,12 @@ def test_parquet_read_int96_timestamps_schema_inference(coerce_to, store_schema)
) as f:
schema = daft.read_parquet(f, coerce_int96_timestamp_unit=coerce_to).schema()
assert schema == expected, f"Expected:\n{expected}\n\nReceived:\n{schema}"


def test_row_groups():
path = ["tests/assets/parquet-data/mvp.parquet"]

df = daft.read_parquet(path).collect()
assert df.count_rows() == 100
# Row groups 0 and 1 of mvp.parquet together hold 20 of its 100 rows.
df = daft.read_parquet(path, row_groups=[[0, 1]]).collect()
assert df.count_rows() == 20
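A hedged variant of this test that controls the row-group layout explicitly instead of relying on the mvp.parquet asset (assumes pyarrow is available and daft is imported at module scope as in the existing tests; the test name and 10-row group size are illustrative):

    def test_row_groups_with_known_layout(tmp_path):
        import pyarrow as pa
        import pyarrow.parquet as pq

        # Write 100 rows split into row groups of 10 rows each.
        file_path = str(tmp_path / "rg.parquet")
        pq.write_table(pa.table({"a": list(range(100))}), file_path, row_group_size=10)

        # Selecting row groups 0 and 3 should return exactly 20 rows.
        df = daft.read_parquet([file_path], row_groups=[[0, 3]]).collect()
        assert df.count_rows() == 20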
