Skip to content
Closed
3 changes: 3 additions & 0 deletions python/pyspark/sql/pandas/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def create_array(s, t):
# Ensure timestamp series are in expected form for Spark internal representation
if t is not None and pa.types.is_timestamp(t):
s = _check_series_convert_timestamps_internal(s, self._timezone)
elif type(s.dtype) == pd.CategoricalDtype:
# Note: This can be removed once minimum pyarrow version is >= 0.16.1
s = s.astype(s.dtypes.categories.dtype)
try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
except pa.ArrowException as e:
Expand Down
2 changes: 2 additions & 0 deletions python/pyspark/sql/pandas/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def from_arrow_type(at):
return StructType(
[StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
for field in at])
elif types.is_dictionary(at):
spark_type = from_arrow_type(at.value_type)
else:
raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
return spark_type
Expand Down
26 changes: 26 additions & 0 deletions python/pyspark/sql/tests/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,32 @@ def run_test(num_records, num_parts, max_records, use_delay=False):
for case in cases:
run_test(*case)

def test_createDateFrame_with_category_type(self):
# Build a pandas DataFrame with a categorical column, convert it to a Spark
# DataFrame once with Arrow enabled and once with Arrow disabled, and check
# that both conversions come back from toPandas() as equal frames.
# NOTE(review): "DateFrame" in the test name looks like a typo for
# "DataFrame" -- confirm before renaming, since the name is the test's id.
pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]})
# Column "B" is column "A" cast to the pandas 'category' dtype.
pdf["B"] = pdf["A"].astype('category')
# Entry 0 of the category index, kept so its Python type can be checked later.
category_first_element = dict(enumerate(pdf['B'].cat.categories))[0]

# Conversion with the Arrow setting switched on.
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
arrow_df = self.spark.createDataFrame(pdf)
# dtypes[1][1]: type name of the second column ("B").
arrow_type = arrow_df.dtypes[1][1]
result_arrow = arrow_df.toPandas()
arrow_first_category_element = result_arrow["B"][0]

# Same conversion with the Arrow setting switched off.
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
df = self.spark.createDataFrame(pdf)
spark_type = df.dtypes[1][1]
result_spark = df.toPandas()
spark_first_category_element = result_spark["B"][0]

# Both conversions must yield equal pandas frames.
assert_frame_equal(result_spark, result_arrow)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add an assert that the Spark DataFrame has column "B" as a string type?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done; I moved the other test checks here too.

# The original pandas category elements must be plain strings.
self.assertIsInstance(category_first_element, str)
# Column "B" must be reported as 'string' by both the Arrow-enabled and the
# Arrow-disabled DataFrame; checked separately so a failure names the culprit.
self.assertEqual(spark_type, 'string')
self.assertEqual(arrow_type, 'string')
# Values read back through toPandas() must be plain strings on both paths.
self.assertIsInstance(arrow_first_category_element, str)
self.assertIsInstance(spark_first_category_element, str)


@unittest.skipIf(
not have_pandas or not have_pyarrow,
Expand Down
21 changes: 21 additions & 0 deletions python/pyspark/sql/tests/test_pandas_udf_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,27 @@ def test_timestamp_dst(self):
result = df.withColumn('time', foo_udf(df.time))
self.assertEquals(df.collect(), result.collect())

def test_udf_category_type(self):
    """Check that a pandas UDF declared as 'string' may return a categorical Series.

    The UDF casts its input to the pandas ``category`` dtype. The derived
    column "B" must be reported by Spark as ``string`` and, after
    ``toPandas()``, must hold the same ``str`` values as source column "A".
    """

    @pandas_udf('string')
    def to_category_func(x):
        return x.astype('category')

    pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]})
    df = self.spark.createDataFrame(pdf)
    df = df.withColumn("B", to_category_func(df['A']))
    result_spark = df.toPandas()

    # dtypes[1][1]: type name of the second column ("B").
    spark_type = df.dtypes[1][1]
    self.assertEqual(spark_type, 'string')

    # Guard against the value checks below passing vacuously on an empty result.
    self.assertEqual(len(result_spark), len(pdf))

    # Every value of "B" must equal the matching value of "A", and both must be str.
    for a_value, b_value in zip(result_spark["A"], result_spark["B"]):
        self.assertEqual(a_value, b_value)
        self.assertIsInstance(a_value, str)
        self.assertIsInstance(b_value, str)

@unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.")
def test_type_annotation(self):
# Regression test to check if type hints can be used. See SPARK-23569.
Expand Down