Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 32 additions & 14 deletions python/pyspark/sql/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,21 @@ def test_infer_schema_not_enough_names(self):
df = self.spark.createDataFrame([["a", "b"]], ["col1"])
self.assertEqual(df.columns, ["col1", "_2"])

def test_infer_schema_fails(self):
with self.assertRaisesRegex(TypeError, "field a"):
self.spark.createDataFrame(
self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]),
schema=["a", "b"],
samplingRatio=0.99,
)
def test_infer_schema_upcast_int_to_string(self):
df = self.spark.createDataFrame(
self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]),
schema=["a", "b"],
samplingRatio=0.99,
)
self.assertEqual([Row(a="1", b=1), Row(a="x", b=1)], df.collect())

def test_infer_schema_upcast_float_to_string(self):
df = self.spark.createDataFrame([[1.33, 1], ["2.1", 1]], schema=["a", "b"])
self.assertEqual([Row(a="1.33", b=1), Row(a="2.1", b=1)], df.collect())

def test_infer_schema_upcast_boolean_to_string(self):
df = self.spark.createDataFrame([[True, 1], ["false", 1]], schema=["a", "b"])
self.assertEqual([Row(a="true", b=1), Row(a="false", b=1)], df.collect())

def test_infer_nested_schema(self):
NestedRow = Row("f1", "f2")
Expand Down Expand Up @@ -316,8 +324,10 @@ def test_infer_array_merge_element_types(self):
self.assertRaises(ValueError, lambda: self.spark.createDataFrame(data3))

# an array with conflicting types should raise an error
# in this case this is ArrayType(StringType) and ArrayType(NullType)
data4 = [ArrayRow([1, "1"], [None])]
self.assertRaises(TypeError, lambda: self.spark.createDataFrame(data4))
with self.assertRaisesRegex(ValueError, "types cannot be determined after inferring"):
self.spark.createDataFrame(data4)

def test_infer_array_element_type_empty(self):
# SPARK-39168: Test inferring array element type from all rows
Expand Down Expand Up @@ -840,8 +850,12 @@ def test_merge_type(self):
_merge_type(MapType(StringType(), LongType()), MapType(StringType(), LongType())),
MapType(StringType(), LongType()),
)
with self.assertRaisesRegex(TypeError, "key of map"):
_merge_type(MapType(StringType(), LongType()), MapType(DoubleType(), LongType()))

self.assertEqual(
_merge_type(MapType(StringType(), LongType()), MapType(DoubleType(), LongType())),
MapType(StringType(), LongType()),
)

with self.assertRaisesRegex(TypeError, "value of map"):
_merge_type(MapType(StringType(), LongType()), MapType(StringType(), DoubleType()))

Expand All @@ -865,11 +879,13 @@ def test_merge_type(self):
),
StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
)
with self.assertRaisesRegex(TypeError, "field f2 in field f1"):
self.assertEqual(
_merge_type(
StructType([StructField("f1", StructType([StructField("f2", LongType())]))]),
StructType([StructField("f1", StructType([StructField("f2", StringType())]))]),
)
),
StructType([StructField("f1", StructType([StructField("f2", StringType())]))]),
)

self.assertEqual(
_merge_type(
Expand Down Expand Up @@ -937,11 +953,13 @@ def test_merge_type(self):
),
StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
)
with self.assertRaisesRegex(TypeError, "key of map element in array field f1"):
self.assertEqual(
_merge_type(
StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))]),
)
),
StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]),
)

# test for SPARK-16542
def test_array_types(self):
Expand Down
4 changes: 4 additions & 0 deletions python/pyspark/sql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1648,6 +1648,10 @@ def new_name(n: str) -> str:
return a
elif isinstance(a, TimestampNTZType) and isinstance(b, TimestampType):
return b
elif isinstance(a, AtomicType) and isinstance(b, StringType):
return b
elif isinstance(a, StringType) and isinstance(b, AtomicType):
return a
elif type(a) is not type(b):
# TODO: type cast (such as int -> long)
raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b))))
Expand Down