feat(pyspark): add read_csv, read_parquet, and register
gforsyth authored and cpcloud committed Mar 23, 2023
1 parent d6235ee commit 7bd22af
Showing 3 changed files with 233 additions and 10 deletions.
5 changes: 2 additions & 3 deletions ibis/backends/duckdb/__init__.py
@@ -187,9 +187,8 @@ def register(
parquet/csv files, an iterable of parquet or CSV files, a pandas
dataframe, a pyarrow table or dataset, or a postgres URI.
table_name
An optional name to use for the created table. This defaults to the
filename if a path (with hyphens replaced with underscores), or
sequentially generated name otherwise.
An optional name to use for the created table. This defaults to a
sequentially generated name.
**kwargs
Additional keyword arguments passed to DuckDB loading functions for
CSV or parquet. See https://duckdb.org/docs/data/csv and
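For context, a minimal sketch of the updated DuckDB default naming (assumptions only: ibis.duckdb.connect() and a local diamonds.csv are illustrative, and the generated name shown is an example, not taken from the diff):

import ibis

con = ibis.duckdb.connect()
t = con.register("diamonds.csv")  # table name is now sequentially generated, e.g. ibis_read_csv_0
named = con.register("diamonds.csv", table_name="diamonds")  # explicit names behave as before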
130 changes: 130 additions & 0 deletions ibis/backends/pyspark/__init__.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
from pathlib import Path
from typing import TYPE_CHECKING, Any

@@ -13,6 +14,7 @@
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis import util
from ibis.backends.base.df.scope import Scope
from ibis.backends.base.df.timecontext import canonicalize_context, localize_context
from ibis.backends.base.sql import BaseSQLBackend
@@ -37,6 +39,16 @@
'escape': '"',
}

pa_n = itertools.count(0)
csv_n = itertools.count(0)


def normalize_filenames(source_list):
# Promote to list
source_list = util.promote_list(source_list)

return list(map(util.normalize_filename, source_list))


class _PySparkCursor:
"""Spark cursor.
@@ -574,3 +586,121 @@ def _clean_up_cached_table(self, op):
assert t.is_cached
t.unpersist()
assert not t.is_cached

def read_parquet(
self,
source: str | Path,
table_name: str | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a parquet file as a table in the current database.
Parameters
----------
source
The data source. May be a path to a file or directory of parquet files.
table_name
An optional name to use for the created table. This defaults to
a sequentially generated name.
kwargs
Additional keyword arguments passed to PySpark.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.parquet.html
Returns
-------
ir.Table
The just-registered table
"""
source = util.normalize_filename(source)
spark_df = self._session.read.parquet(source, **kwargs)
table_name = table_name or f"ibis_read_parquet_{next(pa_n)}"

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)
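A minimal usage sketch of the new read_parquet (not part of the diff; the SparkSession and the trips.parquet path are hypothetical):

import ibis
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
con = ibis.pyspark.connect(spark)

trips = con.read_parquet("trips.parquet")                      # temp view named e.g. ibis_read_parquet_0
named = con.read_parquet("trips.parquet", table_name="trips")  # explicit view name
print(named.count().execute())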

def read_csv(
self,
source_list: str | list[str] | tuple[str],
table_name: str | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a CSV file as a table in the current database.
Parameters
----------
source_list
The data source(s). May be a path to a file or directory of CSV files, or an
iterable of CSV files.
table_name
An optional name to use for the created table. This defaults to
a sequentially generated name.
kwargs
Additional keyword arguments passed to PySpark loading function.
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.csv.html
Returns
-------
ir.Table
The just-registered table
"""
source_list = normalize_filenames(source_list)
spark_df = self._session.read.csv(source_list, **kwargs)
table_name = table_name or f"ibis_read_csv_{next(csv_n)}"

spark_df.createOrReplaceTempView(table_name)
return self.table(table_name)
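A corresponding sketch for read_csv (same hypothetical connection as above; a.csv and b.csv are illustrative, and header is forwarded to pyspark.sql.DataFrameReader.csv):

single = con.read_csv("a.csv", header=True)                 # temp view named e.g. ibis_read_csv_0
many = con.read_csv(["a.csv", "b.csv"], table_name="pair")  # several files registered as one view
print(many.count().execute())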

def register(
self,
source: str | Path | Any,
table_name: str | None = None,
**kwargs: Any,
) -> ir.Table:
"""Register a data source as a table in the current database.
Parameters
----------
source
The data source(s). May be a path to a file or directory of
parquet/csv files, or an iterable of CSV files.
table_name
An optional name to use for the created table. This defaults to
a sequentially generated name.
**kwargs
Additional keyword arguments passed to PySpark loading functions for
CSV or parquet.
Returns
-------
ir.Table
The just-registered table
"""

if isinstance(source, (str, Path)):
first = str(source)
elif isinstance(source, (list, tuple)):
first = source[0]
else:
self._register_failure()

if first.startswith(("parquet://", "parq://")) or first.endswith(
("parq", "parquet")
):
return self.read_parquet(source, table_name=table_name, **kwargs)
elif first.startswith(
("csv://", "csv.gz://", "txt://", "txt.gz://")
) or first.endswith(("csv", "csv.gz", "tsv", "tsv.gz", "txt", "txt.gz")):
return self.read_csv(source, table_name=table_name, **kwargs)
else:
self._register_failure() # noqa: RET503

def _register_failure(self):
import inspect

msg = ", ".join(
name for name, _ in inspect.getmembers(self) if name.startswith("read_")
)
raise ValueError(
f"Cannot infer appropriate read function for input, "
f"please call one of {msg} directly"
)
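register() simply dispatches on URI scheme or file extension to the two readers above; a hedged sketch with hypothetical paths:

con.register("parquet://warehouse/trips.parquet")  # routed to read_parquet
con.register("trips.csv.gz", table_name="trips")   # routed to read_csv
con.register(["a.csv", "b.csv"])                   # first element decides: read_csv
# anything else raises ValueError listing the available read_* methods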
108 changes: 101 additions & 7 deletions ibis/backends/tests/test_register.py
@@ -41,7 +41,15 @@ def gzip_csv(data_directory, tmp_path):
("fname", "in_table_name", "out_table_name"),
[
param("diamonds.csv", None, "ibis_read_csv_", id="default"),
param("csv://diamonds.csv", "Diamonds2", "Diamonds2", id="csv_name"),
param(
"csv://diamonds.csv",
"Diamonds2",
"Diamonds2",
id="csv_name",
marks=pytest.mark.notyet(
["pyspark"], reason="pyspark lowercases view names"
),
),
param(
"file://diamonds.csv",
"fancy_stones",
@@ -53,11 +61,14 @@ def gzip_csv(data_directory, tmp_path):
"fancy stones",
"fancy stones",
id="file_atypical_name",
marks=pytest.mark.notyet(
["pyspark"], reason="no spaces allowed in view names"
),
),
param(
["file://diamonds.csv", "diamonds.csv"],
"fancy stones",
"fancy stones",
"fancy_stones2",
"fancy_stones2",
id="multi_csv",
marks=pytest.mark.notyet(
["polars", "datafusion"],
@@ -76,7 +87,6 @@
"mysql",
"pandas",
"postgres",
"pyspark",
"snowflake",
"sqlite",
"trino",
@@ -102,7 +112,6 @@ def test_register_csv(con, data_directory, fname, in_table_name, out_table_name)
"mysql",
"pandas",
"postgres",
"pyspark",
"snowflake",
"sqlite",
"trino",
@@ -125,7 +134,6 @@ def test_register_csv_gz(con, data_directory, gzip_csv):
"mysql",
"pandas",
"postgres",
"pyspark",
"snowflake",
"sqlite",
"trino",
@@ -179,7 +187,6 @@ def read_table(path: Path) -> Iterator[tuple[str, pa.Table]]:
"mysql",
"pandas",
"postgres",
"pyspark",
"snowflake",
"sqlite",
"trino",
@@ -381,3 +388,90 @@ def test_register_garbage(con, monkeypatch):

with pytest.raises(FileNotFoundError):
con.read_parquet("garbage_notafile")


@pytest.mark.parametrize(
("fname", "in_table_name", "out_table_name"),
[
(
"functional_alltypes.parquet",
None,
"ibis_read_parquet",
),
("functional_alltypes.parquet", "funk_all", "funk_all"),
],
)
@pytest.mark.notyet(
[
"bigquery",
"clickhouse",
"dask",
"impala",
"mssql",
"mysql",
"pandas",
"postgres",
"snowflake",
"sqlite",
"trino",
]
)
def test_read_parquet(
con, tmp_path, data_directory, fname, in_table_name, out_table_name
):
pq = pytest.importorskip("pyarrow.parquet")

fname = Path(fname)
table = read_table(data_directory / fname.name)

pq.write_table(table, tmp_path / fname.name)

with pushd(data_directory):
if con.name == "pyspark":
# pyspark doesn't respect CWD
fname = str(Path(fname).absolute())
table = con.read_parquet(fname, table_name=in_table_name)

assert any(t.startswith(out_table_name) for t in con.list_tables())

if con.name != "datafusion":
table.count().execute()


@pytest.mark.parametrize(
("fname", "in_table_name", "out_table_name"),
[
param("diamonds.csv", None, "ibis_read_csv_", id="default"),
param(
"diamonds.csv",
"fancy_stones",
"fancy_stones",
id="file_name",
),
],
)
@pytest.mark.notyet(
[
"bigquery",
"clickhouse",
"dask",
"impala",
"mssql",
"mysql",
"pandas",
"postgres",
"snowflake",
"sqlite",
"trino",
]
)
def test_read_csv(con, data_directory, fname, in_table_name, out_table_name):
with pushd(data_directory):
if con.name == "pyspark":
# pyspark doesn't respect CWD
fname = str(Path(fname).absolute())
table = con.read_csv(fname, table_name=in_table_name)

assert any(t.startswith(out_table_name) for t in con.list_tables())
if con.name != "datafusion":
table.count().execute()
