diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index bfc0e0c5726e8..e67012319904a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -145,6 +145,7 @@ def test_literal_integers(self):
             cdf.select(CF.lit(JVM_LONG_MIN - 1)).show()
 
     def test_cast(self):
+        # SPARK-41412: test basic Column.cast
         df = self.connect.read.table(self.tbl_name)
         df2 = self.spark.read.table(self.tbl_name)
 
@@ -152,22 +153,25 @@ def test_cast(self):
             df.select(df.id.cast("string")).toPandas(), df2.select(df2.id.cast("string")).toPandas()
         )
-        for x in [
-            StringType(),
-            BinaryType(),
-            ShortType(),
-            IntegerType(),
-            LongType(),
-            FloatType(),
-            DoubleType(),
-            ByteType(),
-            DecimalType(10, 2),
-            BooleanType(),
-            DayTimeIntervalType(),
-        ]:
-            self.assert_eq(
-                df.select(df.id.cast(x)).toPandas(), df2.select(df2.id.cast(x)).toPandas()
-            )
+        # Test if the arguments can be passed properly.
+        # Do not need to check individual behaviour for the ANSI mode thoroughly.
+        with self.sql_conf({"spark.sql.ansi.enabled": False}):
+            for x in [
+                StringType(),
+                BinaryType(),
+                ShortType(),
+                IntegerType(),
+                LongType(),
+                FloatType(),
+                DoubleType(),
+                ByteType(),
+                DecimalType(10, 2),
+                BooleanType(),
+                DayTimeIntervalType(),
+            ]:
+                self.assert_eq(
+                    df.select(df.id.cast(x)).toPandas(), df2.select(df2.id.cast(x)).toPandas()
+                )
 
     def test_unsupported_functions(self):
         # SPARK-41225: Disable unsupported functions.