diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 023fbeabcbabc..1027918adbe15 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -29,7 +29,7 @@
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.sql.types import _parse_datatype_json_string
 from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
-from pyspark.sql.readwriter import DataFrameWriter
+from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.streaming import DataStreamWriter
 from pyspark.sql.types import *
 from pyspark.sql.pandas.conversion import PandasConversionMixin
@@ -2240,6 +2240,22 @@ def inputFiles(self):
         sinceversion=1.4,
         doc=":func:`drop_duplicates` is an alias for :func:`dropDuplicates`.")
 
+    @since(3.1)
+    def writeTo(self, table):
+        """
+        Create a write configuration builder for v2 sources.
+
+        This builder is used to configure and execute write operations.
+
+        For example, it can be used to append data to an existing table, or to create or
+        replace a table.
+
+        >>> df.writeTo("catalog.db.table").append()  # doctest: +SKIP
+        >>> df.writeTo(  # doctest: +SKIP
+        ...     "catalog.db.table"
+        ... ).partitionedBy("col").createOrReplace()
+        """
+        return DataFrameWriterV2(self, table)
+
 
 def _to_scala_map(sc, jm):
     """
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 5a352104c4eca..3ca4edafa6873 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -3322,6 +3322,118 @@ def map_zip_with(col1, col2, f):
     return _invoke_higher_order_function("MapZipWith", [col1, col2], [f])
 
 
+# ---------------------- Partition transform functions --------------------------------
+
+@since(3.1)
+def years(col):
+    """
+    Partition transform function: A transform for timestamps and dates
+    to partition data into years.
+
+    >>> df.writeTo("catalog.db.table").partitionedBy(  # doctest: +SKIP
+    ...     years("ts")
+    ... ).createOrReplace()
+
+    .. warning::
+        This function can be used only in combination with the
+        :py:meth:`~pyspark.sql.readwriter.DataFrameWriterV2.partitionedBy`
+        method of the `DataFrameWriterV2`.
+
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.years(_to_java_column(col)))
+
+
+@since(3.1)
+def months(col):
+    """
+    Partition transform function: A transform for timestamps and dates
+    to partition data into months.
+
+    >>> df.writeTo("catalog.db.table").partitionedBy(  # doctest: +SKIP
+    ...     months("ts")
+    ... ).createOrReplace()
+
+    .. warning::
+        This function can be used only in combination with the
+        :py:meth:`~pyspark.sql.readwriter.DataFrameWriterV2.partitionedBy`
+        method of the `DataFrameWriterV2`.
+
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.months(_to_java_column(col)))
+
+
+@since(3.1)
+def days(col):
+    """
+    Partition transform function: A transform for timestamps and dates
+    to partition data into days.
+
+    >>> df.writeTo("catalog.db.table").partitionedBy(  # doctest: +SKIP
+    ...     days("ts")
+    ... ).createOrReplace()
+
+    .. warning::
+        This function can be used only in combination with the
+        :py:meth:`~pyspark.sql.readwriter.DataFrameWriterV2.partitionedBy`
+        method of the `DataFrameWriterV2`.
+
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.days(_to_java_column(col)))
+
+
+@since(3.1)
+def hours(col):
+    """
+    Partition transform function: A transform for timestamps
+    to partition data into hours.
+
+    >>> df.writeTo("catalog.db.table").partitionedBy(  # doctest: +SKIP
hours("ts") + ... ).createOrReplace() + + .. warning:: + This function can be used only in combinatiion with + :py:meth:`~pyspark.sql.readwriter.DataFrameWriterV2.partitionedBy` + method of the `DataFrameWriterV2`. + + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.hours(_to_java_column(col))) + + +@since(3.1) +def bucket(numBuckets, col): + """ + Partition transform function: A transform for any type that partitions + by a hash of the input column. + + >>> df.writeTo("catalog.db.table").partitionedBy( # doctest: +SKIP + ... bucket(42, "ts") + ... ).createOrReplace() + + .. warning:: + This function can be used only in combinatiion with + :py:meth:`~pyspark.sql.readwriter.DataFrameWriterV2.partitionedBy` + method of the `DataFrameWriterV2`. + + """ + if not isinstance(numBuckets, (int, Column)): + raise TypeError( + "numBuckets should be a Column or and int, got {}".format(type(numBuckets)) + ) + + sc = SparkContext._active_spark_context + numBuckets = ( + _create_column_from_literal(numBuckets) + if isinstance(numBuckets, int) + else _to_java_column(numBuckets) + ) + return Column(sc._jvm.functions.bucket(numBuckets, _to_java_column(col))) + + # ---------------------------- User Defined Function ---------------------------------- @since(1.3) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index a83aece2e485d..6925adf567fb6 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -18,7 +18,7 @@ from py4j.java_gateway import JavaClass from pyspark import RDD, since -from pyspark.sql.column import _to_seq +from pyspark.sql.column import _to_seq, _to_java_column from pyspark.sql.types import * from pyspark.sql import utils from pyspark.sql.utils import to_str @@ -1075,6 +1075,145 @@ def jdbc(self, url, table, mode=None, properties=None): self.mode(mode)._jwrite.jdbc(url, table, jprop) +class DataFrameWriterV2(object): + """ + Interface used to write a class:`pyspark.sql.dataframe.DataFrame` + to external storage using the v2 API. + + .. versionadded:: 3.1.0 + """ + + def __init__(self, df, table): + self._df = df + self._spark = df.sql_ctx + self._jwriter = df._jdf.writeTo(table) + + @since(3.1) + def using(self, provider): + """ + Specifies a provider for the underlying output data source. + Spark's default catalog supports "parquet", "json", etc. + """ + self._jwriter.using(provider) + return self + + @since(3.1) + def option(self, key, value): + """ + Add a write option. + """ + self._jwriter.option(key, to_str(value)) + return self + + @since(3.1) + def options(self, **options): + """ + Add write options. + """ + options = {k: to_str(v) for k, v in options.items()} + self._jwriter.options(options) + return self + + @since(3.1) + def tableProperty(self, property, value): + """ + Add table property. + """ + self._jwriter.tableProperty(property, value) + return self + + @since(3.1) + def partitionedBy(self, col, *cols): + """ + Partition the output table created by `create`, `createOrReplace`, or `replace` using + the given columns or transforms. + + When specified, the table data will be stored by these values for efficient reads. + + For example, when a table is partitioned by day, it may be stored + in a directory layout like: + + * `table/day=2019-06-01/` + * `table/day=2019-06-02/` + + Partitioning is one of the most widely used techniques to optimize physical data layout. 
+        It provides a coarse-grained index for skipping unnecessary data reads when queries have
+        predicates on the partitioned columns. In order for partitioning to work well, the number
+        of distinct values in each column should typically be less than tens of thousands.
+
+        `col` and `cols` support only the following functions:
+
+        * :py:func:`pyspark.sql.functions.years`
+        * :py:func:`pyspark.sql.functions.months`
+        * :py:func:`pyspark.sql.functions.days`
+        * :py:func:`pyspark.sql.functions.hours`
+        * :py:func:`pyspark.sql.functions.bucket`
+
+        """
+        col = _to_java_column(col)
+        cols = _to_seq(self._spark._sc, [_to_java_column(c) for c in cols])
+        self._jwriter.partitionedBy(col, cols)
+        return self
+
+    @since(3.1)
+    def create(self):
+        """
+        Create a new table from the contents of the data frame.
+
+        The new table's schema, partition layout, properties, and other configuration will be
+        based on the configuration set on this writer.
+        """
+        self._jwriter.create()
+
+    @since(3.1)
+    def replace(self):
+        """
+        Replace an existing table with the contents of the data frame.
+
+        The existing table's schema, partition layout, properties, and other configuration will be
+        replaced with the contents of the data frame and the configuration set on this writer.
+        """
+        self._jwriter.replace()
+
+    @since(3.1)
+    def createOrReplace(self):
+        """
+        Create a new table or replace an existing table with the contents of the data frame.
+
+        The output table's schema, partition layout, properties,
+        and other configuration will be based on the contents of the data frame
+        and the configuration set on this writer.
+        If the table exists, its configuration and data will be replaced.
+        """
+        self._jwriter.createOrReplace()
+
+    @since(3.1)
+    def append(self):
+        """
+        Append the contents of the data frame to the output table.
+        """
+        self._jwriter.append()
+
+    @since(3.1)
+    def overwrite(self, condition):
+        """
+        Overwrite rows matching the given filter condition with the contents of the data frame in
+        the output table.
+        """
+        condition = _to_java_column(condition)
+        self._jwriter.overwrite(condition)
+
+    @since(3.1)
+    def overwritePartitions(self):
+        """
+        Overwrite all partitions for which the data frame contains at least one row with the
+        contents of the data frame in the output table.
+
+        This operation is equivalent to Hive's `INSERT OVERWRITE ... PARTITION`, which replaces
+        partitions dynamically depending on the contents of the data frame.
+ """ + self._jwriter.overwritePartitions() + + def _test(): import doctest import os diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py index 2530cc2ebf224..8e34d3865c9d8 100644 --- a/python/pyspark/sql/tests/test_readwriter.py +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -19,6 +19,8 @@ import shutil import tempfile +from pyspark.sql.functions import col +from pyspark.sql.readwriter import DataFrameWriterV2 from pyspark.sql.types import * from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -163,6 +165,40 @@ def test_insert_into(self): self.assertEqual(6, self.spark.sql("select * from test_table").count()) +class ReadwriterV2Tests(ReusedSQLTestCase): + def test_api(self): + df = self.df + writer = df.writeTo("testcat.t") + self.assertIsInstance(writer, DataFrameWriterV2) + self.assertIsInstance(writer.option("property", "value"), DataFrameWriterV2) + self.assertIsInstance(writer.options(property="value"), DataFrameWriterV2) + self.assertIsInstance(writer.using("source"), DataFrameWriterV2) + self.assertIsInstance(writer.partitionedBy("id"), DataFrameWriterV2) + self.assertIsInstance(writer.partitionedBy(col("id")), DataFrameWriterV2) + self.assertIsInstance(writer.tableProperty("foo", "bar"), DataFrameWriterV2) + + def test_partitioning_functions(self): + import datetime + from pyspark.sql.functions import years, months, days, hours, bucket + + df = self.spark.createDataFrame( + [(1, datetime.datetime(2000, 1, 1), "foo")], + ("id", "ts", "value") + ) + + writer = df.writeTo("testcat.t") + + self.assertIsInstance(writer.partitionedBy(years("ts")), DataFrameWriterV2) + self.assertIsInstance(writer.partitionedBy(months("ts")), DataFrameWriterV2) + self.assertIsInstance(writer.partitionedBy(days("ts")), DataFrameWriterV2) + self.assertIsInstance(writer.partitionedBy(hours("ts")), DataFrameWriterV2) + self.assertIsInstance(writer.partitionedBy(bucket(11, "id")), DataFrameWriterV2) + self.assertIsInstance(writer.partitionedBy(bucket(11, col("id"))), DataFrameWriterV2) + self.assertIsInstance( + writer.partitionedBy(bucket(3, "id"), hours(col("ts"))), DataFrameWriterV2 + ) + + if __name__ == "__main__": import unittest from pyspark.sql.tests.test_readwriter import *