Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects

import java.lang.reflect.{Method, Modifier}

import scala.collection.immutable
import scala.collection.mutable
import scala.collection.mutable.Builder
import scala.jdk.CollectionConverters._
Expand Down Expand Up @@ -938,6 +939,14 @@ case class MapObjects private(
executeFuncOnCollection(input).foreach(builder += _)
mutable.ArraySeq.make(builder.result())
}
case Some(cls) if classOf[immutable.ArraySeq[_]].isAssignableFrom(cls) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not related to this PR, but this MapObjects expression is getting more and more complicated. Can we rewrite it as a RuntimeReplaceable using StaticInvoke/Invoke so that we can implement it in pure Scala?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, let me give it a try.

implicit val tag: ClassTag[Any] = elementClassTag()
input => {
val builder = mutable.ArrayBuilder.make[Any]
builder.sizeHint(input.size)
executeFuncOnCollection(input).foreach(builder += _)
immutable.ArraySeq.unsafeWrapArray(builder.result())
}
case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) =>
// Scala sequence
executeFuncOnCollection(_).toSeq
Expand Down Expand Up @@ -1108,7 +1117,20 @@ case class MapObjects private(
s"(${cls.getName}) ${classOf[mutable.ArraySeq[_]].getName}$$." +
s"MODULE$$.make($builder.result());"
)

case Some(cls) if classOf[immutable.ArraySeq[_]].isAssignableFrom(cls) =>
val tag = ctx.addReferenceObj("tag", elementClassTag())
val builderClassName = classOf[mutable.ArrayBuilder[_]].getName
val getBuilder = s"$builderClassName$$.MODULE$$.make($tag)"
val builder = ctx.freshName("collectionBuilder")
(
s"""
${classOf[Builder[_, _]].getName} $builder = $getBuilder;
$builder.sizeHint($dataLength);
""",
(genValue: String) => s"$builder.$$plus$$eq($genValue);",
s"(${cls.getName}) ${classOf[immutable.ArraySeq[_]].getName}$$." +
s"MODULE$$.unsafeWrapArray($builder.result());"
)
case Some(cls) if classOf[scala.collection.Seq[_]].isAssignableFrom(cls) ||
classOf[scala.collection.Set[_]].isAssignableFrom(cls) =>
// Scala sequence or set
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import java.sql.{Date, Timestamp}

import scala.collection.immutable
import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
Expand Down Expand Up @@ -363,14 +364,18 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(result.asInstanceOf[java.util.List[_]].asScala == expected)
case a if classOf[mutable.ArraySeq[Int]].isAssignableFrom(a) =>
assert(result == mutable.ArraySeq.make[Int](expected.toArray))
case a if classOf[immutable.ArraySeq[Int]].isAssignableFrom(a) =>
assert(result.isInstanceOf[immutable.ArraySeq[_]])
assert(result == immutable.ArraySeq.unsafeWrapArray[Int](expected.toArray))
case s if classOf[Seq[_]].isAssignableFrom(s) =>
assert(result.asInstanceOf[Seq[_]] == expected)
case s if classOf[scala.collection.Set[_]].isAssignableFrom(s) =>
assert(result.asInstanceOf[scala.collection.Set[_]] == expected.toSet)
}
}

val customCollectionClasses = Seq(classOf[mutable.ArraySeq[Int]],
val customCollectionClasses = Seq(
classOf[mutable.ArraySeq[Int]], classOf[immutable.ArraySeq[Int]],
classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]],
Expand All @@ -392,6 +397,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

Seq(
(Seq(1, 2, 3), ObjectType(classOf[mutable.ArraySeq[Int]])),
(Seq(1, 2, 3), ObjectType(classOf[immutable.ArraySeq[Int]])),
(Seq(1, 2, 3), ObjectType(classOf[Seq[Int]])),
(Array(1, 2, 3), ObjectType(classOf[Array[Int]])),
(Seq(1, 2, 3), ObjectType(classOf[Object])),
Expand Down
17 changes: 17 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.sql.Timestamp
import java.time.{Instant, LocalDate}
import java.time.format.DateTimeFormatter

import scala.collection.immutable
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

Expand Down Expand Up @@ -830,6 +831,22 @@ class UDFSuite extends QueryTest with SharedSparkSession {
Row(ArrayBuffer(100)))
}

test("SPARK-46586: UDF should not fail on immutable.ArraySeq") {
val myUdf1 = udf((a: immutable.ArraySeq[Int]) =>
immutable.ArraySeq.unsafeWrapArray[Int](Array(a.head + 99)))
checkAnswer(Seq(Array(1))
.toDF("col")
.select(myUdf1(Column("col"))),
Row(ArrayBuffer(100)))

val myUdf2 = udf((a: immutable.ArraySeq[Int]) =>
immutable.ArraySeq.unsafeWrapArray[Int](a.appended(5).appended(6).toArray))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

super nit: isn't it more common to use :+ to create a new immutable collection with the new elements?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, let me improve it.

checkAnswer(Seq(Array(1, 2, 3))
.toDF("col")
.select(myUdf2(Column("col"))),
Row(ArrayBuffer(1, 2, 3, 5, 6)))
}

test("SPARK-34388: UDF name is propagated with registration for ScalaUDF") {
spark.udf.register("udf34388", udf((value: Int) => value > 2))
spark.sessionState.catalog.lookupFunction(
Expand Down