diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 15b9f2fdeb0f0..42562e1fb9c46 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -143,10 +143,7 @@ def _create_batch(self, series): import pandas as pd import pyarrow as pa from pyspark.sql.pandas.types import _check_series_convert_timestamps_internal - try: - from pandas import CategoricalDtype - except ImportError: - from pandas.api.types import CategoricalDtype + from pandas.api.types import is_categorical # Make input conform to [(series1, type1), (series2, type2), ...] if not isinstance(series, (list, tuple)) or \ (len(series) == 2 and isinstance(series[1], pa.DataType)): @@ -158,7 +155,7 @@ def create_array(s, t): # Ensure timestamp series are in expected form for Spark internal representation if t is not None and pa.types.is_timestamp(t): s = _check_series_convert_timestamps_internal(s, self._timezone) - elif type(s.dtype) == CategoricalDtype: + elif is_categorical(s.dtype): # Note: This can be removed once minimum pyarrow version is >= 0.16.1 s = s.astype(s.dtypes.categories.dtype) try: