diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 9eb75757bbafa..a1da8fcdd4152 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -421,8 +421,8 @@ def _createFromPandasWithArrow(self, pdf, schema):
         to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
         data types will be used to coerce the data in Pandas to Arrow conversion.
         """
-        from pyspark.serializers import ArrowSerializer
-        from pyspark.sql.types import from_arrow_schema, to_arrow_type, _cast_pandas_series_type
+        from pyspark.serializers import ArrowSerializer, _create_batch
+        from pyspark.sql.types import from_arrow_schema, to_arrow_type
         import pyarrow as pa
 
         # Slice the DataFrame into batches
@@ -446,33 +446,10 @@ def _createFromPandasWithArrow(self, pdf, schema):
             else:
                 schema = schema_from_arrow
         else:
-            batches = []
-            for i, pdf_slice in enumerate(pdf_slices):
-
-                # convert to series to pyarrow.Arrays to use mask when creating Arrow batches
-                arrs = []
-                names = []
-                for c, (_, series) in enumerate(pdf_slice.iteritems()):
-                    field = schema[c]
-                    names.append(field.name)
-                    t = to_arrow_type(field.dataType)
-                    try:
-                        # NOTE: casting is not necessary with Arrow >= 0.7
-                        arrs.append(pa.Array.from_pandas(_cast_pandas_series_type(series, t),
-                                                         mask=series.isnull(), type=t))
-                    except ValueError as e:
-                        warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
-                        return None
-                batches.append(pa.RecordBatch.from_arrays(arrs, names))
-
-                # Verify schema of first batch, return None if not equal and fallback without Arrow
-                if i == 0:
-                    schema_from_arrow = from_arrow_schema(batches[i].schema)
-                    if schema != schema_from_arrow:
-                        warnings.warn("Arrow will not be used in createDataFrame.\n" +
-                                      "Supplied schema: %s\n!=\nArrow schema: %s"
-                                      % (str(schema), str(schema_from_arrow)))
-                        return None
+            arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
+            batches = [_create_batch([(c, t)
+                                      for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)])
+                       for pdf_slice in pdf_slices]
 
         # Create the Spark DataFrame directly from the Arrow data and schema
         jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer())
@@ -580,10 +557,11 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
         if has_pandas and isinstance(data, pandas.DataFrame):
             if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
                     and len(data) > 0:
-                df = self._createFromPandasWithArrow(data, schema)
-                # Fallback to create DataFrame without arrow if return None
-                if df is not None:
-                    return df
+                try:
+                    return self._createFromPandasWithArrow(data, schema)
+                except Exception as e:
+                    warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
+                    # Fallback to create DataFrame without arrow if raise some exception
             if schema is None:
                 schema = [str(x) for x in data.columns]
             data = [r.tolist() for r in data.to_records(index=False)]
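
The second hunk replaces the inline per-column loop with the shared `_create_batch` helper from `pyspark.serializers`, which does the same job: convert each pandas Series to a `pyarrow.Array` (using the series' null mask so missing values become Arrow nulls) and assemble the arrays into a `pyarrow.RecordBatch`. Below is a minimal sketch of that technique for reference — `create_batch` is a hypothetical standalone name, not the actual `_create_batch` implementation, which also handles casting and other edge cases:

```python
import pandas as pd
import pyarrow as pa

def create_batch(series_and_types):
    """Build a pyarrow.RecordBatch from (pandas.Series, Arrow type) pairs,
    passing each series' null mask so missing values become Arrow nulls."""
    arrs = [pa.Array.from_pandas(s, mask=s.isnull(), type=t)
            for s, t in series_and_types]
    # Positional placeholder names; Spark applies the real schema JVM-side.
    return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])

# Example: two columns, each with a null, converted under explicit Arrow types
batch = create_batch([
    (pd.Series([1.0, 2.0, None]), pa.float64()),
    (pd.Series(["a", "b", None]), pa.string()),
])
print(batch.num_rows, batch.schema)
```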
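
The third hunk changes the fallback contract: instead of `_createFromPandasWithArrow` returning `None` on failure (and scattering `warnings.warn` calls at each failure site), it now raises, and `createDataFrame` catches any exception, emits a single warning, and falls through to the non-Arrow path. From the caller's side the behavior looks like this sketch, assuming an active `SparkSession` bound to `spark`:

```python
import pandas as pd

# Enable the Arrow conversion path (off by default at this point).
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

pdf = pd.DataFrame({"a": [1, 2, 3], "b": [0.1, 0.2, 0.3]})

# The Arrow path is attempted first; if it raises (unsupported type, schema
# mismatch, ...), a warning is emitted and createDataFrame transparently
# falls back to the slower row-by-row conversion.
df = spark.createDataFrame(pdf)
df.show()
```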