
Commit 541dbc0

ueshin authored and cloud-fan committed
[SPARK-23054][SQL][PYSPARK][FOLLOWUP] Use sqlType casting when casting PythonUserDefinedType to String.
## What changes were proposed in this pull request?

This is a follow-up of #20246.

If a UDT in Python doesn't have a corresponding Scala UDT, casting it to string produces the raw string of the internal value, e.g. `"org.apache.spark.sql.catalyst.expressions.UnsafeArrayDataxxxxxxxx"` if the internal type is `ArrayType`. This PR fixes it by casting via the UDT's `sqlType` instead.

## How was this patch tested?

Added a test and existing tests.

Author: Takuya UESHIN <[email protected]>

Closes #20306 from ueshin/issues/SPARK-23054/fup1.

(cherry picked from commit 568055d)
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 225b1af
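For illustration, a minimal PySpark sketch of the behavior this patch changes, using the `PythonOnlyPoint`/`PythonOnlyUDT` test classes from `pyspark.sql.tests`; the pre-patch output in the comment is illustrative, not verbatim:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT
from pyspark.sql.types import StructField, StructType

spark = SparkSession.builder.getOrCreate()

# A UDT defined only in Python: internally it is stored as its sqlType,
# which for PythonOnlyUDT is an array of doubles.
schema = StructType([StructField("pypoint", PythonOnlyUDT(), False)])
df = spark.createDataFrame([(PythonOnlyPoint(3.0, 4.0),)], schema)

# Before this patch the cast rendered the raw internal value, e.g.
# "org.apache.spark.sql.catalyst.expressions.UnsafeArrayData...".
# After the patch it delegates to the sqlType cast and prints "[3.0, 4.0]".
df.select(col("pypoint").cast("string")).show(truncate=False)
```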

File tree

3 files changed: +15 −0 lines changed

python/pyspark/sql/tests.py

Lines changed: 11 additions & 0 deletions
@@ -1189,6 +1189,17 @@ def test_union_with_udt(self):
             ]
         )
 
+    def test_cast_to_string_with_udt(self):
+        from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
+        from pyspark.sql.functions import col
+        row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
+        schema = StructType([StructField("point", ExamplePointUDT(), False),
+                             StructField("pypoint", PythonOnlyUDT(), False)])
+        df = self.spark.createDataFrame([row], schema)
+
+        result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head()
+        self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]'))
+
     def test_column_operators(self):
         ci = self.df.key
         cs = self.df.value
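For context, a "Python-only" UDT is one whose class does not define `scalaUDT()` (unlike `ExamplePointUDT`), so Catalyst sees it as a `PythonUserDefinedType` whose internal representation is its `sqlType`. A minimal sketch, loosely modeled on `PythonOnlyUDT` in `pyspark.sql.tests`; the `MyPoint`/`MyPointUDT` names are hypothetical:

```python
from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType


class MyPoint(object):
    """Hypothetical value class for illustration."""
    def __init__(self, x, y):
        self.x = x
        self.y = y


class MyPointUDT(UserDefinedType):
    """A UDT with no Scala counterpart: scalaUDT() is left at its default,
    so on the JVM side this surfaces as a PythonUserDefinedType whose
    internal representation is the sqlType below."""

    @classmethod
    def sqlType(cls):
        # Internal storage type; with this patch, casting a column of this
        # UDT to string formats the array, yielding e.g. "[3.0, 4.0]".
        return ArrayType(DoubleType(), False)

    @classmethod
    def module(cls):
        return "__main__"

    def serialize(self, obj):
        return [obj.x, obj.y]

    def deserialize(self, datum):
        return MyPoint(datum[0], datum[1])
```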

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 2 additions & 0 deletions
@@ -282,6 +282,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
           builder.append("]")
           builder.build()
         })
+    case pudt: PythonUserDefinedType => castToString(pudt.sqlType)
     case udt: UserDefinedType[_] =>
       buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString))
     case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
@@ -838,6 +839,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
            |$evPrim = $buffer.build();
          """.stripMargin
       }
+    case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx)
     case udt: UserDefinedType[_] =>
       val udtRef = ctx.addReferenceObj("udt", udt)
       (c, evPrim, evNull) => {
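The same delegation is added to both the interpreted path (`castToString`) and the codegen path (`castToStringCode`), so both render the UDT's underlying `sqlType` value. A small PySpark sketch of the formatting being reused, on a plain array column with no UDT involved:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

# Casting an ArrayType(DoubleType()) column to string uses the array
# branch shown above ("[", elements, "]"). The new PythonUserDefinedType
# branch reuses exactly this formatting, so a Python-only UDT backed by
# such an array now prints "[3.0, 4.0]" instead of the raw UnsafeArrayData.
df = spark.createDataFrame([([3.0, 4.0],)], ["xs"])
df.select(col("xs").cast("string")).show(truncate=False)
```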

sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ private[sql] class ExamplePoint(val x: Double, val y: Double) extends Serializab
     case that: ExamplePoint => this.x == that.x && this.y == that.y
     case _ => false
   }
+
+  override def toString(): String = s"($x, $y)"
 }
 
 /**
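The `toString` override matters because the Scala-UDT branch in `Cast` shown above formats via `udt.deserialize(o).toString`; without it, `ExamplePoint` would fall back to the JVM default `Object.toString`. A quick sketch of the expected result; the pre-override output in the comment is illustrative:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
from pyspark.sql.types import StructField, StructType

spark = SparkSession.builder.getOrCreate()

schema = StructType([StructField("point", ExamplePointUDT(), False)])
df = spark.createDataFrame([(ExamplePoint(1.0, 2.0),)], schema)

# Prints "(1.0, 2.0)" via the new toString; before the override the JVM
# default would have produced something like
# "org.apache.spark.sql.test.ExamplePoint@1a2b3c4d".
df.select(col("point").cast("string")).show()
```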
