[FEAT]: support optional row_groups in read_parquet #2534

Merged
2 changes: 2 additions & 0 deletions daft/daft.pyi
@@ -190,11 +190,13 @@ class ParquetSourceConfig:

coerce_int96_timestamp_unit: PyTimeUnit | None
field_id_mapping: dict[int, PyField] | None
row_groups: list[list[int]] | None

def __init__(
self,
coerce_int96_timestamp_unit: PyTimeUnit | None = None,
field_id_mapping: dict[int, PyField] | None = None,
row_groups: list[list[int]] | None = None,
): ...

class CsvSourceConfig:
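The stub change above adds `row_groups` to `ParquetSourceConfig` as an optional list of per-file row-group indices. A minimal, purely illustrative sketch of constructing the config directly (assuming the native module is importable as `daft.daft`, as the stub path suggests; in normal use `read_parquet` builds this for you):

```python
from daft.daft import ParquetSourceConfig

# Hypothetical illustration: read row groups 0 and 2 from the first file
# and only row group 1 from the second file.
cfg = ParquetSourceConfig(row_groups=[[0, 2], [1]])
```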
10 changes: 9 additions & 1 deletion daft/io/_parquet.py
@@ -20,6 +20,7 @@
@PublicAPI
def read_parquet(
path: Union[str, List[str]],
row_groups: Optional[List[List[int]]] = None,
schema_hints: Optional[Dict[str, DataType]] = None,
io_config: Optional["IOConfig"] = None,
use_native_downloader: bool = True,
@@ -33,9 +34,11 @@ def read_parquet(
>>> df = daft.read_parquet("/path/to/directory")
>>> df = daft.read_parquet("/path/to/files-*.parquet")
>>> df = daft.read_parquet("s3://path/to/files-*.parquet")
>>> df = daft.read_parquet("gs://path/to/files-*.parquet")

Args:
path (str): Path to Parquet file (allows for wildcards)
row_groups (List[List[int]]): List of row groups to read, one list per file path provided.
schema_hints (dict[str, DataType]): A mapping between column names and datatypes - passing this option
will override the specified columns on the inferred schema with the specified DataTypes
io_config (IOConfig): Config to be used with the native downloader
@@ -63,8 +66,13 @@

pytimeunit = coerce_int96_timestamp_unit._timeunit if coerce_int96_timestamp_unit is not None else None

if isinstance(path, list) and row_groups is not None and len(path) != len(row_groups):
raise ValueError("row_groups must be the same length as the list of paths provided.")
if isinstance(row_groups, list) and not isinstance(path, list):
raise ValueError("row_groups are only supported when reading multiple non-globbed/wildcarded files")

file_format_config = FileFormatConfig.from_parquet_config(
ParquetSourceConfig(coerce_int96_timestamp_unit=pytimeunit)
ParquetSourceConfig(coerce_int96_timestamp_unit=pytimeunit, row_groups=row_groups)
)
if use_native_downloader:
storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config))
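A hedged usage sketch of the new argument, based on the signature and the two validation branches above (the paths are placeholders):

```python
import daft

paths = ["data/part-0.parquet", "data/part-1.parquet"]

# One row-group list per path: groups 0 and 1 from the first file,
# only group 3 from the second.
df = daft.read_parquet(paths, row_groups=[[0, 1], [3]])

# Mismatched lengths raise ValueError:
#   "row_groups must be the same length as the list of paths provided."
# daft.read_parquet(paths, row_groups=[[0]])

# Passing row_groups with a single (possibly globbed) path also raises,
# since row_groups are only supported with an explicit list of paths.
# daft.read_parquet("data/part-*.parquet", row_groups=[[0]])
```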
6 changes: 6 additions & 0 deletions src/daft-micropartition/src/micropartition.rs
@@ -130,10 +130,12 @@ fn materialize_scan_task(
FileFormatConfig::Parquet(ParquetSourceConfig {
coerce_int96_timestamp_unit,
field_id_mapping,
..
}) => {
let inference_options =
ParquetSchemaInferenceOptions::new(Some(*coerce_int96_timestamp_unit));
let urls = urls.collect::<Vec<_>>();

let row_groups = parquet_sources_to_row_groups(scan_task.sources.as_slice());
let metadatas = scan_task
.sources
@@ -569,6 +571,7 @@ impl MicroPartition {
FileFormatConfig::Parquet(ParquetSourceConfig {
coerce_int96_timestamp_unit,
field_id_mapping,
..
}),
StorageConfig::Native(cfg),
) => {
@@ -589,6 +592,7 @@
.collect::<Option<Vec<_>>>();

let row_groups = parquet_sources_to_row_groups(scan_task.sources.as_slice());

read_parquet_into_micropartition(
uris.as_slice(),
columns.as_deref(),
@@ -1135,6 +1139,7 @@ pub(crate) fn read_parquet_into_micropartition(
.zip(metadata)
.zip(
row_groups
.clone()
.unwrap_or_else(|| std::iter::repeat(None).take(uris.len()).collect()),
)
.map(|((url, metadata), rgs)| DataFileSource::AnonymousDataFile {
@@ -1150,6 +1155,7 @@
FileFormatConfig::Parquet(ParquetSourceConfig {
coerce_int96_timestamp_unit: schema_infer_options.coerce_int96_timestamp_unit,
field_id_mapping,
row_groups,
})
.into(),
scan_task_daft_schema,
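The `unwrap_or_else` above pads a missing `row_groups` with one `None` per URI before zipping, so every file ends up paired with an optional selection. A small Python sketch of that pairing, for illustration only:

```python
def pair_uris_with_row_groups(uris, row_groups=None):
    # Mirrors `row_groups.unwrap_or_else(|| repeat(None).take(uris.len()).collect())`:
    # an absent selection becomes one None per file, meaning "read all row groups".
    if row_groups is None:
        row_groups = [None] * len(uris)
    return list(zip(uris, row_groups))

pair_uris_with_row_groups(["a.parquet", "b.parquet"], [[0, 1], None])
# [('a.parquet', [0, 1]), ('b.parquet', None)]
```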
54 changes: 34 additions & 20 deletions src/daft-scan/src/anonymous.rs
@@ -4,8 +4,9 @@ use common_error::DaftResult;
use daft_core::schema::SchemaRef;

use crate::{
file_format::FileFormatConfig, storage_config::StorageConfig, DataFileSource, PartitionField,
Pushdowns, ScanOperator, ScanTask, ScanTaskRef,
file_format::{FileFormatConfig, ParquetSourceConfig},
storage_config::StorageConfig,
ChunkSpec, DataFileSource, PartitionField, Pushdowns, ScanOperator, ScanTask, ScanTaskRef,
};
#[derive(Debug)]
pub struct AnonymousScanOperator {
@@ -70,24 +71,37 @@ impl ScanOperator for AnonymousScanOperator {
let schema = self.schema.clone();
let storage_config = self.storage_config.clone();

let row_groups = if let FileFormatConfig::Parquet(ParquetSourceConfig {
row_groups: Some(row_groups),
..
}) = self.file_format_config.as_ref()
{
row_groups.clone()
} else {
std::iter::repeat(None).take(files.len()).collect()
};

// Create one ScanTask per file.
Ok(Box::new(files.into_iter().map(move |f| {
Ok(ScanTask::new(
vec![DataFileSource::AnonymousDataFile {
path: f.to_string(),
chunk_spec: None,
size_bytes: None,
metadata: None,
partition_spec: None,
statistics: None,
parquet_metadata: None,
}],
file_format_config.clone(),
schema.clone(),
storage_config.clone(),
pushdowns.clone(),
)
.into())
})))
Ok(Box::new(files.into_iter().zip(row_groups).map(
move |(f, rg)| {
let chunk_spec = rg.map(ChunkSpec::Parquet);
Ok(ScanTask::new(
vec![DataFileSource::AnonymousDataFile {
path: f.to_string(),
chunk_spec,
size_bytes: None,
metadata: None,
partition_spec: None,
statistics: None,
parquet_metadata: None,
}],
file_format_config.clone(),
schema.clone(),
storage_config.clone(),
pushdowns.clone(),
)
.into())
},
)))
}
}
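The operator still emits one ScanTask per file, but now zips files against their row-group selections and wraps a present selection in `ChunkSpec::Parquet`. A rough Python analogue of that per-file mapping (the dict fields are illustrative stand-ins for the Rust types):

```python
def build_scan_sources(files, row_groups=None):
    # No row_groups on the Parquet config -> one None per file, matching the
    # `repeat(None).take(files.len())` fallback above.
    selections = row_groups if row_groups is not None else [None] * len(files)
    return [
        {
            "path": path,
            # `rg.map(ChunkSpec::Parquet)`: keep None when no selection was given.
            "chunk_spec": ("Parquet", rg) if rg is not None else None,
        }
        for path, rg in zip(files, selections)
    ]
```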
22 changes: 22 additions & 0 deletions src/daft-scan/src/file_format.rs
@@ -121,6 +121,7 @@ pub struct ParquetSourceConfig {
///
/// See: https://github.com/apache/parquet-format/blob/master/src/main/thrift/parquet.thrift#L456-L459
pub field_id_mapping: Option<Arc<BTreeMap<i32, Field>>>,
pub row_groups: Option<Vec<Option<Vec<i64>>>>,
}

impl ParquetSourceConfig {
@@ -140,6 +141,25 @@ impl ParquetSourceConfig {
.join(",")
));
}
if let Some(row_groups) = &self.row_groups {
res.push(format!(
"Row Groups = {{{}}}",
row_groups
.iter()
.map(|rg| {
rg.as_ref()
.map(|rg| {
rg.iter()
.map(|i| i.to_string())
.collect::<Vec<String>>()
.join(",")
})
.unwrap_or_else(|| "None".to_string())
})
.collect::<Vec<String>>()
.join(",")
));
}
res
}
}
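Under the display logic above, a selection like `[Some([0, 1]), None, Some([5])]` renders as `Row Groups = {0,1,None,5}`. A quick Python mirror of that join, only for illustrating the explain/debug string:

```python
def format_row_groups(row_groups):
    # Each file's selection is comma-joined, a missing selection prints as "None",
    # and the per-file strings are themselves comma-joined inside braces.
    parts = [
        ",".join(str(i) for i in rg) if rg is not None else "None"
        for rg in row_groups
    ]
    return "Row Groups = {" + ",".join(parts) + "}"

format_row_groups([[0, 1], None, [5]])  # 'Row Groups = {0,1,None,5}'
```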
@@ -152,6 +172,7 @@ impl ParquetSourceConfig {
fn new(
coerce_int96_timestamp_unit: Option<PyTimeUnit>,
field_id_mapping: Option<BTreeMap<i32, PyField>>,
row_groups: Option<Vec<Option<Vec<i64>>>>,
) -> Self {
Self {
coerce_int96_timestamp_unit: coerce_int96_timestamp_unit
Expand All @@ -162,6 +183,7 @@ impl ParquetSourceConfig {
map.into_iter().map(|(k, v)| (k, v.field)),
))
}),
row_groups,
}
}

24 changes: 20 additions & 4 deletions src/daft-scan/src/glob.rs
@@ -11,7 +11,7 @@ use snafu::Snafu;
use crate::{
file_format::{CsvSourceConfig, FileFormatConfig, ParquetSourceConfig},
storage_config::StorageConfig,
DataFileSource, PartitionField, Pushdowns, ScanOperator, ScanTask, ScanTaskRef,
ChunkSpec, DataFileSource, PartitionField, Pushdowns, ScanOperator, ScanTask, ScanTaskRef,
};
#[derive(Debug)]
pub struct GlobScanOperator {
@@ -167,6 +167,7 @@ impl GlobScanOperator {
FileFormatConfig::Parquet(ParquetSourceConfig {
coerce_int96_timestamp_unit,
field_id_mapping,
..
}) => {
let io_stats = IOStatsContext::new(format!(
"GlobScanOperator constructor read_parquet_schema: for uri {first_filepath}"
@@ -299,8 +300,19 @@ impl ScanOperator for GlobScanOperator {
let schema = self.schema.clone();
let storage_config = self.storage_config.clone();
let is_ray_runner = self.is_ray_runner;

let row_groups = if let FileFormatConfig::Parquet(ParquetSourceConfig {
row_groups: Some(row_groups),
..
}) = self.file_format_config.as_ref()
{
Some(row_groups.clone())
} else {
None
};

// Create one ScanTask per file
Ok(Box::new(files.map(move |f| {
Ok(Box::new(files.enumerate().map(move |(idx, f)| {
let FileMetadata {
filepath: path,
size: size_bytes,
@@ -337,11 +349,15 @@
} else {
None
};

let row_group = row_groups
.as_ref()
.and_then(|rgs| rgs.get(idx).cloned())
.flatten();
let chunk_spec = row_group.map(ChunkSpec::Parquet);
Ok(ScanTask::new(
vec![DataFileSource::AnonymousDataFile {
path: path.to_string(),
chunk_spec: None,
chunk_spec,
size_bytes,
metadata: None,
partition_spec: None,
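In the glob path the selection is looked up by the enumerated file index, so row-group lists line up positionally with the globbed file listing. A hedged Python sketch of the `rgs.get(idx).cloned().flatten()` lookup:

```python
def row_group_for_index(row_groups, idx):
    # No row_groups at all, an index past the end, or an explicit None entry
    # all mean "read every row group of this file".
    if row_groups is None or idx >= len(row_groups):
        return None
    return row_groups[idx]

row_group_for_index([[0, 1], None], 0)  # [0, 1]
row_group_for_index([[0, 1], None], 5)  # None
```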
9 changes: 9 additions & 0 deletions tests/io/test_parquet.py
@@ -138,3 +138,12 @@ def test_parquet_read_int96_timestamps_schema_inference(coerce_to, store_schema)
) as f:
schema = daft.read_parquet(f, coerce_int96_timestamp_unit=coerce_to).schema()
assert schema == expected, f"Expected:\n{expected}\n\nReceived:\n{schema}"


def test_row_groups():
path = ["tests/assets/parquet-data/mvp.parquet"]

df = daft.read_parquet(path).collect()
assert df.count_rows() == 100
df = daft.read_parquet(path, row_groups=[[0, 1]]).collect()
assert df.count_rows() == 20
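The assertions imply that mvp.parquet holds 100 rows and that row groups 0 and 1 together hold 20 of them. To confirm the file's layout independently, a hedged check with pyarrow (assuming it is available in the test environment):

```python
import pyarrow.parquet as pq

pf = pq.ParquetFile("tests/assets/parquet-data/mvp.parquet")
total = pf.metadata.num_rows                                          # 100 per the test above
first_two = sum(pf.metadata.row_group(i).num_rows for i in range(2))  # should be 20
print(total, first_two)
```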