diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 92ef7ce313026..d006c7ab46e63 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -297,8 +297,11 @@ class SparkConversionMixin(object):
     """
     Min-in for the conversion from pandas to Spark. Currently, only
     :class:`SparkSession` can use this class.
+    pandasRDD=True creates a DataFrame from an RDD of pandas.DataFrames
+    (currently only supported when Arrow is enabled).
     """
-    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
+    def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True,
+                        pandasRDD=False):
         from pyspark.sql import SparkSession
         assert isinstance(self, SparkSession)
 
@@ -308,6 +311,14 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
 
         timezone = self._wrapped._conf.sessionLocalTimeZone()
 
+        if self._wrapped._conf.arrowPySparkEnabled() and pandasRDD:
+            from pyspark.rdd import RDD
+            if not isinstance(data, RDD):
+                raise ValueError('pandasRDD is set but data is of type %s, expected RDD type.'
+                                 % type(data))
+            # TODO: Support non-Arrow conversion? It might be *very* slow.
+            return self._create_from_pandas_rdd_with_arrow(data, schema, timezone)
+
         # If no schema supplied by user then get the names of columns only
         if schema is None:
             schema = [str(x) if not isinstance(x, str) else
@@ -353,30 +364,8 @@ def _convert_from_pandas(self, pdf, schema, timezone):
         assert isinstance(self, SparkSession)
 
         if timezone is not None:
-            from pyspark.sql.pandas.types import _check_series_convert_timestamps_tz_local
-            copied = False
-            if isinstance(schema, StructType):
-                for field in schema:
-                    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
-                    if isinstance(field.dataType, TimestampType):
-                        s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone)
-                        if s is not pdf[field.name]:
-                            if not copied:
-                                # Copy once if the series is modified to prevent the original
-                                # Pandas DataFrame from being updated
-                                pdf = pdf.copy()
-                                copied = True
-                            pdf[field.name] = s
-            else:
-                for column, series in pdf.iteritems():
-                    s = _check_series_convert_timestamps_tz_local(series, timezone)
-                    if s is not series:
-                        if not copied:
-                            # Copy once if the series is modified to prevent the original
-                            # Pandas DataFrame from being updated
-                            pdf = pdf.copy()
-                            copied = True
-                        pdf[column] = s
+            from pyspark.sql.pandas.types import _check_dataframe_convert_timestamps_tz_local
+            pdf = _check_dataframe_convert_timestamps_tz_local(pdf, timezone, schema)
 
         # Convert pandas.DataFrame to list of numpy records
         np_records = pdf.to_records(index=False)
@@ -421,6 +410,39 @@ def _get_numpy_record_dtype(self, rec):
             record_type_list.append((str(col_names[i]), curr_type))
         return np.dtype(record_type_list) if has_rec_fix else None
 
+    def _create_from_pandas_rdd_with_arrow(self, prdd, schema, timezone):
+        """
+        Create a DataFrame from an RDD of pandas.DataFrames by converting each
+        pandas.DataFrame to an Arrow RecordBatch which is then sent to the JVM.
+        If a schema is passed in, the data types will be used to coerce the data in
+        the Pandas to Arrow conversion.
+ """ + import pandas as pd + import pyarrow as pa + + safecheck = self._wrapped._conf.arrowSafeTypeConversion() + + # In case no schema is passed, extract inferred schema from the first record batch + from pyspark.sql.pandas.types import from_arrow_schema + if schema is None: + schema = from_arrow_schema(pa.Schema.from_pandas(prdd.first())) + + # Convert to an RDD of arrow record batches + rb_rdd = (prdd. + filter(lambda x: isinstance(x, pd.DataFrame)). + map(lambda x: _dataframe_to_arrow_record_batch(x, + timezone=timezone, + schema=schema, + safecheck=safecheck))) + + # Create Spark DataFrame from Arrow record batches RDD + from pyspark.sql.dataframe import DataFrame + jrdd = rb_rdd._to_java_object_rdd() + jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), self._wrapped._jsqlContext) + df = DataFrame(jdf, self._wrapped) + df._schema = schema + return df + def _create_from_pandas_with_arrow(self, pdf, schema, timezone): """ Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting @@ -491,6 +513,67 @@ def create_RDD_server(): return df +def _sanitize_arrow_schema(schema): + import pyarrow as pa + import re + sanitized_fields = [] + + # Convert pyarrow schema to a spark compatible one + _SPARK_DISALLOWED_CHARS = re.compile('[ ,;{}()\n\t=]') + + def _sanitized_spark_field_name(name): + return _SPARK_DISALLOWED_CHARS.sub('_', name) + + for field in schema: + name = field.name + sanitized_name = _sanitized_spark_field_name(name) + + if sanitized_name != name: + sanitized_field = pa.field(sanitized_name, field.type, + field.nullable, field.metadata) + sanitized_fields.append(sanitized_field) + else: + sanitized_fields.append(field) + + new_schema = pa.schema(sanitized_fields, metadata=schema.metadata) + return new_schema + + +def _dataframe_to_arrow_record_batch(pdf, schema=None, timezone=None, safecheck=False): + """ + Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting + to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the + data types will be used to coerce the data in Pandas to Arrow conversion. 
+ """ + import pyarrow as pa + from pyspark.sql.pandas.types import to_arrow_schema, from_arrow_schema + from pyspark.sql.pandas.utils import require_minimum_pandas_version, \ + require_minimum_pyarrow_version + + require_minimum_pandas_version() + require_minimum_pyarrow_version() + + # Determine arrow types to coerce data when creating batches + if schema is not None: + arrow_schema = to_arrow_schema(schema) + else: + # Any timestamps must be coerced to be compatible with Spark + arrow_schema = to_arrow_schema(from_arrow_schema(pa.Schema.from_pandas(pdf))) + + # Sanitize arrow schema for spark compatibility + arrow_schema = _sanitize_arrow_schema(arrow_schema) + + # Create an Arrow record batch, one batch per DF + from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer + arrow_data = [(pdf[col_name], arrow_type) for col_name, arrow_type + in zip(arrow_schema.names, arrow_schema.types)] + + col_by_name = True # col by name only applies to StructType columns, can't happen here + ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name) + + return bytearray(ser._create_batch(arrow_data).serialize()) + + def _test(): import doctest from pyspark.sql import SparkSession diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 489b46691a135..925e47083a3a2 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -328,6 +328,42 @@ def _check_series_convert_timestamps_tz_local(s, timezone): return _check_series_convert_timestamps_localize(s, timezone, None) +def _check_dataframe_covert_timestamps_tz_local(pdf, timezone, schema=None): + """ + Convert timestamp to timezone-naive in the specified timezone or local timezone + + :param pdf: a pandas.DataFrame + :param timezone: the timezone to convert from. if None then use local timezone + :param schema: an optional spark schema that defines which timestamp columns to inspect + :return pandas.DataFrame where if it is a timestamp, has been converted to tz-naive + """ + copied = False + if isinstance(schema, StructType): + for field in schema: + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? + if isinstance(field.dataType, TimestampType): + s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone) + if s is not pdf[field.name]: + if not copied: + # Copy once if the series is modified to prevent the original + # Pandas DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[field.name] = s + else: + for column, series in pdf.iteritems(): + s = _check_series_convert_timestamps_tz_local(series, timezone) + if s is not series: + if not copied: + # Copy once if the series is modified to prevent the original + # Pandas DataFrame from being updated + pdf = pdf.copy() + copied = True + pdf[column] = s + + return pdf + + def _convert_map_items_to_dict(s): """ Convert a series with items as list of (key, value), as made from an Arrow column of map type, diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 740ceb31f7d16..c6badf01063ab 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -552,9 +552,11 @@ def _create_shell_session(): return SparkSession.builder.getOrCreate() - def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True): + def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True, + pandasRDD=False): """ - Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`. 
+        Creates a :class:`DataFrame` from an :class:`RDD`, an :class:`RDD[pandas.DataFrame]`,
+        a list or a :class:`pandas.DataFrame`.
 
         When ``schema`` is a list of column names, the type of each column
         will be inferred from ``data``.
 
@@ -580,9 +582,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
         Parameters
         ----------
         data : :class:`RDD` or iterable
-            an RDD of any kind of SQL data representation (:class:`Row`,
-            :class:`tuple`, ``int``, ``boolean``, etc.), or :class:`list`, or
-            :class:`pandas.DataFrame`.
+            an RDD of any kind of SQL data representation (e.g. :class:`Row`,
+            :class:`tuple`, ``int``, ``boolean``, :class:`pandas.DataFrame`, etc.),
+            or :class:`list`, or :class:`pandas.DataFrame`.
         schema : :class:`pyspark.sql.types.DataType`, str or list, optional
             a :class:`pyspark.sql.types.DataType` or a datatype string or a list of
             column names, default is None. The data type string format equals to
@@ -594,6 +596,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
             the sample ratio of rows used for inferring
         verifySchema : bool, optional
             verify data types of every row against schema. Enabled by default.
+        pandasRDD : bool, optional
+            indicates that the input RDD contains :class:`pandas.DataFrame` objects.
 
         Returns
         -------
@@ -637,6 +641,15 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
         >>> df3.collect()
         [Row(name='Alice', age=1)]
 
+        >>> # doctest: +SKIP
+        ... prdd = sc.range(0, 4).map(lambda x: pandas.DataFrame([[x,]*4], columns=list('ABCD')))
+        ... df4 = spark.createDataFrame(prdd, schema=None, pandasRDD=True)
+        ... df4.collect()
+        [Row(A=0, B=0, C=0, D=0),
+         Row(A=1, B=1, C=1, D=1),
+         Row(A=2, B=2, C=2, D=2),
+         Row(A=3, B=3, C=3, D=3)]
+
         >>> spark.createDataFrame(df.toPandas()).collect()  # doctest: +SKIP
         [Row(name='Alice', age=1)]
         >>> spark.createDataFrame(pandas.DataFrame([[1, 2]])).collect()  # doctest: +SKIP
@@ -668,10 +681,10 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
             has_pandas = True
         except Exception:
             has_pandas = False
-        if has_pandas and isinstance(data, pandas.DataFrame):
+        if has_pandas and (isinstance(data, pandas.DataFrame) or pandasRDD):
             # Create a DataFrame from pandas DataFrame.
             return super(SparkSession, self).createDataFrame(
-                data, schema, samplingRatio, verifySchema)
+                data, schema, samplingRatio, verifySchema, pandasRDD)
         return self._create_dataframe(data, schema, samplingRatio, verifySchema)
 
     def _create_dataframe(self, data, schema, samplingRatio, verifySchema):
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index cca9ec406506a..81007fd20f047 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -349,6 +349,17 @@ def test_schema_conversion_roundtrip(self):
         schema_rt = from_arrow_schema(arrow_schema)
         self.assertEqual(self.schema, schema_rt)
 
+    def test_createDataFrame_from_pandas_rdd(self):
+        pdfs = [self.create_pandas_data_frame() for _ in range(4)]
+        prdd = self.sc.parallelize(pdfs)
+
+        df_from_rdd = self.spark.createDataFrame(prdd, schema=self.schema, pandasRDD=True)
+        df_from_pdf = self.spark.createDataFrame(pd.concat(pdfs), schema=self.schema)
+
+        result_prdd = df_from_rdd.toPandas()
+        result_single_pdf = df_from_pdf.toPandas()
+        assert_frame_equal(result_prdd, result_single_pdf)
+
     def test_createDataFrame_with_array_type(self):
         pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
         df, df_arrow = self._createDataFrame_toggle(pdf)
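For reviewers, a minimal usage sketch of the pandasRDD flag added above (not part of the patch). It assumes a build with this change applied, pandas and pyarrow installed on the driver and executors, and Arrow execution enabled via spark.sql.execution.arrow.pyspark.enabled; the session name and column names are illustrative only.

    from pyspark.sql import SparkSession
    import pandas as pd

    spark = (SparkSession.builder
             .config("spark.sql.execution.arrow.pyspark.enabled", "true")
             .getOrCreate())

    # Each RDD element is a small pandas.DataFrame; with pandasRDD=True every element
    # is converted into one Arrow RecordBatch and the batches are assembled into a
    # single Spark DataFrame. The schema is inferred from the first element when no
    # schema is supplied.
    prdd = spark.sparkContext.range(0, 4).map(
        lambda i: pd.DataFrame({"A": [i], "B": [i * 10]}))

    df = spark.createDataFrame(prdd, pandasRDD=True)
    df.show()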