diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index f248afa3d839d..ee04c94cbd5d8 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -1139,7 +1139,12 @@ def createDataFrame(  # type: ignore[misc]
             require_minimum_pandas_version()
             if data.ndim not in [1, 2]:
                 raise ValueError("NumPy array input should be of 1 or 2 dimensions.")
-            column_names = ["value"] if data.ndim == 1 else ["_1", "_2"]
+
+            if data.ndim == 1 or data.shape[1] == 1:
+                column_names = ["value"]
+            else:
+                column_names = ["_%s" % i for i in range(1, data.shape[1] + 1)]
+
             if schema is None and not self._jconf.arrowPySparkEnabled():
                 # Construct `schema` from `np.dtype` of the input NumPy array
                 # TODO: Apply the logic below when self._jconf.arrowPySparkEnabled() is True
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index fdba431726c83..6083f31ac81b9 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -188,8 +188,10 @@ def create_np_arrs(self):
         return (
             [np.array([1, 2]).astype(t) for t in int_dtypes]
             + [np.array([0.1, 0.2]).astype(t) for t in float_dtypes]
-            + [np.array([[1, 2], [3, 4]]).astype(t) for t in int_dtypes]
-            + [np.array([[0.1, 0.2], [0.3, 0.4]]).astype(t) for t in float_dtypes]
+            + [np.array([[1], [2]]).astype(t) for t in int_dtypes]
+            + [np.array([[0.1], [0.2]]).astype(t) for t in float_dtypes]
+            + [np.array([[1, 1, 1], [2, 2, 2]]).astype(t) for t in int_dtypes]
+            + [np.array([[0.1, 0.1, 0.1], [0.2, 0.2, 0.2]]).astype(t) for t in float_dtypes]
         )
 
     def test_toPandas_fallback_enabled(self):
@@ -510,9 +512,11 @@ def test_schema_conversion_roundtrip(self):
 
     def test_createDataFrame_with_ndarray(self):
         dtypes = ["tinyint", "smallint", "int", "bigint", "float", "double"]
-        expected_dtypes = [[("value", t)] for t in dtypes] + [
-            [("_1", t), ("_2", t)] for t in dtypes
-        ]
+        expected_dtypes = (
+            [[("value", t)] for t in dtypes]
+            + [[("value", t)] for t in dtypes]
+            + [[("_1", t), ("_2", t), ("_3", t)] for t in dtypes]
+        )
         arrs = self.create_np_arrs
         for arr, dtypes in zip(arrs, expected_dtypes):
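
For reference, a minimal usage sketch of the column-naming behavior this diff introduces. It is not part of the patch itself; it assumes pandas is installed (per `require_minimum_pandas_version()`) and that a local SparkSession named `spark` is available.

```python
import numpy as np
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

# 1-D input and single-column 2-D input both map to a single "value" column.
spark.createDataFrame(np.array([1, 2])).columns          # ['value']
spark.createDataFrame(np.array([[1], [2]])).columns      # ['value']

# 2-D input with n > 1 columns now maps to "_1" .. "_n", instead of the
# previously hard-coded ["_1", "_2"] that only covered two-column arrays.
spark.createDataFrame(np.array([[1, 2, 3], [4, 5, 6]])).columns  # ['_1', '_2', '_3']
```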