44 changes: 11 additions & 33 deletions python/pyspark/sql/session.py
@@ -421,8 +421,8 @@ def _createFromPandasWithArrow(self, pdf, schema):
         to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
         data types will be used to coerce the data in Pandas to Arrow conversion.
         """
-        from pyspark.serializers import ArrowSerializer
-        from pyspark.sql.types import from_arrow_schema, to_arrow_type, _cast_pandas_series_type
+        from pyspark.serializers import ArrowSerializer, _create_batch
+        from pyspark.sql.types import from_arrow_schema, to_arrow_type
         import pyarrow as pa

         # Slice the DataFrame into batches
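
The hunk above ends with the comment about slicing the pandas DataFrame into batches. As a quick, hypothetical illustration of that kind of slicing (the variable names and slice count are illustrative, not taken from the diff):

    import pandas as pd

    pdf = pd.DataFrame({"a": range(10)})
    num_slices = 3                                   # e.g. one slice per Spark partition
    step = -(-len(pdf) // num_slices)                # ceiling division: 4 rows per slice here
    pdf_slices = [pdf[start:start + step] for start in range(0, len(pdf), step)]
    # pdf_slices now holds three DataFrames of 4, 4 and 2 rows.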
@@ -446,33 +446,10 @@ def _createFromPandasWithArrow(self, pdf, schema):
             else:
                 schema = schema_from_arrow
         else:
-            batches = []
-            for i, pdf_slice in enumerate(pdf_slices):
-
-                # convert to series to pyarrow.Arrays to use mask when creating Arrow batches
-                arrs = []
-                names = []
-                for c, (_, series) in enumerate(pdf_slice.iteritems()):
-                    field = schema[c]
-                    names.append(field.name)
-                    t = to_arrow_type(field.dataType)
-                    try:
-                        # NOTE: casting is not necessary with Arrow >= 0.7
-                        arrs.append(pa.Array.from_pandas(_cast_pandas_series_type(series, t),
-                                                         mask=series.isnull(), type=t))
-                    except ValueError as e:
-                        warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
-                        return None
-                batches.append(pa.RecordBatch.from_arrays(arrs, names))
-
-                # Verify schema of first batch, return None if not equal and fallback without Arrow
-                if i == 0:
-                    schema_from_arrow = from_arrow_schema(batches[i].schema)
-                    if schema != schema_from_arrow:
-                        warnings.warn("Arrow will not be used in createDataFrame.\n" +
-                                      "Supplied schema: %s\n!=\nArrow schema: %s"
-                                      % (str(schema), str(schema_from_arrow)))
-                        return None
+            arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
+            batches = [_create_batch([(c, t)
+                                      for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)])
+                       for pdf_slice in pdf_slices]

Collaborator: BTW, looks like this is the only place using zip. Not a big deal, but I think we are safe to replace

    from itertools import imap as map

with

    from itertools import izip as zip, imap as map

We could change this after merging this first too.

Owner: done
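
For context on the import suggestion above: in Python 2, these itertools aliases make zip and map return lazy iterators instead of lists, matching Python 3 semantics. A small illustration (Python 2 only; the sample values are illustrative):

    # Python 2: shadow the built-ins with itertools' lazy variants.
    from itertools import izip as zip, imap as map

    pairs = zip(xrange(3), "abc")              # an iterator, nothing materialized yet
    squares = map(lambda x: x * x, xrange(4))  # likewise lazy
    print(list(pairs))                         # [(0, 'a'), (1, 'b'), (2, 'c')]
    print(list(squares))                       # [0, 1, 4, 9]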

         # Create the Spark DataFrame directly from the Arrow data and schema
         jrdd = self._sc._serialize_to_jvm(batches, len(batches), ArrowSerializer())
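
For reference, the deleted loop in the hunk above is what per-slice Arrow batch construction looked like; the new code delegates it to _create_batch in pyspark.serializers. A minimal sketch of that construction, based on the pyarrow calls in the removed code (the helper name and signature here are illustrative, not the actual _create_batch implementation):

    import pyarrow as pa

    def build_arrow_batch(pdf_slice, names_and_types):
        # names_and_types: list of (column name, pyarrow type) pairs derived from the schema.
        arrs, names = [], []
        # pdf_slice.iteritems() as used in the diff (older pandas API; items() in newer pandas).
        for (name, arrow_type), (_, series) in zip(names_and_types, pdf_slice.iteritems()):
            names.append(name)
            # Pass the null mask so missing values are preserved in the Arrow array.
            arrs.append(pa.Array.from_pandas(series, mask=series.isnull(), type=arrow_type))
        return pa.RecordBatch.from_arrays(arrs, names)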
@@ -580,10 +557,11 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
         if has_pandas and isinstance(data, pandas.DataFrame):
             if self.conf.get("spark.sql.execution.arrow.enabled", "false").lower() == "true" \
                     and len(data) > 0:
-                df = self._createFromPandasWithArrow(data, schema)
-                # Fallback to create DataFrame without arrow if return None
-                if df is not None:
-                    return df
+                try:
+                    return self._createFromPandasWithArrow(data, schema)
+                except Exception as e:
+                    warnings.warn("Arrow will not be used in createDataFrame: %s" % str(e))
+                    # Fallback to create DataFrame without arrow if raise some exception
             if schema is None:
                 schema = [str(x) for x in data.columns]
             data = [r.tolist() for r in data.to_records(index=False)]
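
Finally, a hypothetical usage sketch of the new fallback behavior (the local SparkSession and the sample data are illustrative, not part of the diff): with the Arrow flag enabled, createDataFrame first tries the Arrow path and, if it raises, emits the warning and falls back to the row-based conversion shown above.

    import pandas as pd
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").appName("arrow-fallback-demo").getOrCreate()
    spark.conf.set("spark.sql.execution.arrow.enabled", "true")

    pdf = pd.DataFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]})

    # createDataFrame tries _createFromPandasWithArrow first; on any exception it warns
    # "Arrow will not be used in createDataFrame: ..." and falls back to the non-Arrow path.
    df = spark.createDataFrame(pdf)
    df.show()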