diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 4214dd71bb8e..b1d2eccea436 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -235,13 +235,21 @@ def test_infer_schema_not_enough_names(self): df = self.spark.createDataFrame([["a", "b"]], ["col1"]) self.assertEqual(df.columns, ["col1", "_2"]) - def test_infer_schema_fails(self): - with self.assertRaisesRegex(TypeError, "field a"): - self.spark.createDataFrame( - self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]), - schema=["a", "b"], - samplingRatio=0.99, - ) + def test_infer_schema_upcast_int_to_string(self): + df = self.spark.createDataFrame( + self.spark.sparkContext.parallelize([[1, 1], ["x", 1]]), + schema=["a", "b"], + samplingRatio=0.99, + ) + self.assertEqual([Row(a="1", b=1), Row(a="x", b=1)], df.collect()) + + def test_infer_schema_upcast_float_to_string(self): + df = self.spark.createDataFrame([[1.33, 1], ["2.1", 1]], schema=["a", "b"]) + self.assertEqual([Row(a="1.33", b=1), Row(a="2.1", b=1)], df.collect()) + + def test_infer_schema_upcast_boolean_to_string(self): + df = self.spark.createDataFrame([[True, 1], ["false", 1]], schema=["a", "b"]) + self.assertEqual([Row(a="true", b=1), Row(a="false", b=1)], df.collect()) def test_infer_nested_schema(self): NestedRow = Row("f1", "f2") @@ -316,8 +324,10 @@ def test_infer_array_merge_element_types(self): self.assertRaises(ValueError, lambda: self.spark.createDataFrame(data3)) # an array with conflicting types should raise an error + # in this case this is ArrayType(StringType) and ArrayType(NullType) data4 = [ArrayRow([1, "1"], [None])] - self.assertRaises(TypeError, lambda: self.spark.createDataFrame(data4)) + with self.assertRaisesRegex(ValueError, "types cannot be determined after inferring"): + self.spark.createDataFrame(data4) def test_infer_array_element_type_empty(self): # SPARK-39168: Test inferring array element type from all rows @@ -840,8 +850,12 @@ def test_merge_type(self): _merge_type(MapType(StringType(), LongType()), MapType(StringType(), LongType())), MapType(StringType(), LongType()), ) - with self.assertRaisesRegex(TypeError, "key of map"): - _merge_type(MapType(StringType(), LongType()), MapType(DoubleType(), LongType())) + + self.assertEqual( + _merge_type(MapType(StringType(), LongType()), MapType(DoubleType(), LongType())), + MapType(StringType(), LongType()), + ) + with self.assertRaisesRegex(TypeError, "value of map"): _merge_type(MapType(StringType(), LongType()), MapType(StringType(), DoubleType())) @@ -865,11 +879,13 @@ def test_merge_type(self): ), StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), ) - with self.assertRaisesRegex(TypeError, "field f2 in field f1"): + self.assertEqual( _merge_type( StructType([StructField("f1", StructType([StructField("f2", LongType())]))]), StructType([StructField("f1", StructType([StructField("f2", StringType())]))]), - ) + ), + StructType([StructField("f1", StructType([StructField("f2", StringType())]))]), + ) self.assertEqual( _merge_type( @@ -937,11 +953,13 @@ def test_merge_type(self): ), StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), ) - with self.assertRaisesRegex(TypeError, "key of map element in array field f1"): + self.assertEqual( _merge_type( StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), StructType([StructField("f1", ArrayType(MapType(DoubleType(), LongType())))]), - ) + ), + StructType([StructField("f1", ArrayType(MapType(StringType(), LongType())))]), + ) # test for SPARK-16542 def test_array_types(self): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 365c903487ce..9cb17e85540f 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1648,6 +1648,10 @@ def new_name(n: str) -> str: return a elif isinstance(a, TimestampNTZType) and isinstance(b, TimestampType): return b + elif isinstance(a, AtomicType) and isinstance(b, StringType): + return b + elif isinstance(a, StringType) and isinstance(b, AtomicType): + return a elif type(a) is not type(b): # TODO: type cast (such as int -> long) raise TypeError(new_msg("Can not merge type %s and %s" % (type(a), type(b))))