Expose register_listing_table
This lets users use `object_store` with Python DataFusion
for partitioned datasets, e.g. in S3.

Closes #617
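
A minimal usage sketch for the S3 case the message mentions (bucket name, region, credentials, and the exact `AmazonS3` constructor keywords are illustrative assumptions; the new test below exercises the same API against a `LocalFileSystem`):

from datafusion import SessionContext
from datafusion.object_store import AmazonS3

ctx = SessionContext()

# Hypothetical bucket and credentials; the keyword names are assumptions,
# check datafusion.object_store for the exact AmazonS3 signature.
s3 = AmazonS3(
    bucket_name="my-bucket",
    region="us-east-1",
    access_key_id="<access-key-id>",
    secret_access_key="<secret-access-key>",
)
# Argument layout mirrors the register_object_store call in the test below.
ctx.register_object_store("s3://my-bucket", s3, None)

# Register a hive-partitioned Parquet dataset rooted at the given URL.
ctx.register_listing_table(
    "my_table",
    "s3://my-bucket/dataset/",
    table_partition_cols=[("grp", "string"), ("date_id", "int")],
    file_extension=".parquet",
)
batches = ctx.sql("SELECT grp, COUNT(*) FROM my_table GROUP BY grp").collect()
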
Henri Froese committed Mar 24, 2024
1 parent 7e3c0e1 commit a43684a
Showing 2 changed files with 109 additions and 2 deletions.
58 changes: 57 additions & 1 deletion datafusion/tests/test_sql.py
@@ -21,8 +21,9 @@
import pyarrow as pa
import pyarrow.dataset as ds
import pytest
from datafusion.object_store import LocalFileSystem

from datafusion import udf
from datafusion import udf, col

from . import generic as helpers

@@ -374,3 +375,58 @@ def test_simple_select(ctx, tmp_path, arr):
    result = batches[0].column(0)

    np.testing.assert_equal(result, arr)


@pytest.mark.parametrize("file_sort_order", (None, [[col("int").sort(True, True)]]))
@pytest.mark.parametrize("pass_schema", (True, False))
def test_register_listing_table(ctx, tmp_path, pass_schema, file_sort_order):
dir_root = tmp_path / "dataset_parquet_partitioned"
dir_root.mkdir(exist_ok=False)
(dir_root / "grp=a/date_id=20201005").mkdir(exist_ok=False, parents=True)
(dir_root / "grp=a/date_id=20211005").mkdir(exist_ok=False, parents=True)
(dir_root / "grp=b/date_id=20201005").mkdir(exist_ok=False, parents=True)

table = pa.Table.from_arrays(
[
[1, 2, 3, 4, 5, 6, 7],
["a", "b", "c", "d", "e", "f", "g"],
[1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7],
],
names=["int", "str", "float"],
)
pa.parquet.write_table(
table.slice(0, 3), dir_root / "grp=a/date_id=20201005/file.parquet"
)
pa.parquet.write_table(
table.slice(3, 2), dir_root / "grp=a/date_id=20211005/file.parquet"
)
pa.parquet.write_table(
table.slice(5, 10), dir_root / "grp=b/date_id=20201005/file.parquet"
)

ctx.register_object_store("file://local", LocalFileSystem(), None)
ctx.register_listing_table(
"my_table",
f"file://{dir_root}/",
table_partition_cols=[("grp", "string"), ("date_id", "int")],
file_extension=".parquet",
schema=table.schema if pass_schema else None,
file_sort_order=file_sort_order,
)
assert ctx.tables() == {"my_table"}

result = ctx.sql(
"SELECT grp, COUNT(*) AS count FROM my_table GROUP BY grp"
).collect()
result = pa.Table.from_batches(result)

rd = result.to_pydict()
assert dict(zip(rd["grp"], rd["count"])) == {"a": 5, "b": 2}

result = ctx.sql(
"SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005 GROUP BY grp"
).collect()
result = pa.Table.from_batches(result)

rd = result.to_pydict()
assert dict(zip(rd["grp"], rd["count"])) == {"a": 3, "b": 2}
53 changes: 52 additions & 1 deletion src/context.rs
@@ -39,10 +39,14 @@ use crate::store::StorageContexts;
use crate::udaf::PyAggregateUDF;
use crate::udf::PyScalarUDF;
use crate::utils::{get_tokio_runtime, wait_for_future};
use datafusion::arrow::datatypes::{DataType, Schema};
use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::listing::{
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
};
use datafusion::datasource::MemTable;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::{SessionConfig, SessionContext, SessionState, TaskContext};
@@ -278,6 +282,53 @@ impl PySessionContext {
Ok(())
}

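/// Construct a `ListingTable` over the (possibly hive-partitioned) files under `path`
/// and register it with the context under `name`. Only the Parquet format is wired up here.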
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (name, path, table_partition_cols=vec![],
file_extension=".parquet",
schema=None,
file_sort_order=None))]
pub fn register_listing_table(
&mut self,
name: &str,
path: &str,
table_partition_cols: Vec<(String, String)>,
file_extension: &str,
schema: Option<PyArrowType<Schema>>,
file_sort_order: Option<Vec<Vec<PyExpr>>>,
py: Python,
) -> PyResult<()> {
let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
.with_file_extension(file_extension)
.with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.with_file_sort_order(
file_sort_order
.unwrap_or_default()
.into_iter()
.map(|e| e.into_iter().map(|f| f.into()).collect())
.collect(),
);
let table_path = ListingTableUrl::parse(path)?;
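// Use the caller-supplied schema if one was passed; otherwise infer it from the
// files at `table_path`, waiting on the async inference via `wait_for_future`.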
let resolved_schema: SchemaRef = match schema {
Some(s) => Arc::new(s.0),
None => {
let state = self.ctx.state();
let schema = options.infer_schema(&state, &table_path);
wait_for_future(py, schema).map_err(DataFusionError::from)?
}
};
let config = ListingTableConfig::new(table_path)
.with_listing_options(options)
.with_schema(resolved_schema);
let table = ListingTable::try_new(config)?;
self.register_table(
name,
&PyTable {
table: Arc::new(table),
},
)?;
Ok(())
}

/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
pub fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> {
let result = self.ctx.sql(query);
