diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 8c2759a384..cdcfb42807 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.{CometTestBase, DataFrame, Row}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyExtractValueOps
 import org.apache.spark.sql.comet.{CometColumnarToRowExec, CometProjectExec}
-import org.apache.spark.sql.execution.{InputAdapter, ProjectExec, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{InputAdapter, LocalTableScanExec, ProjectExec, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -2545,4 +2545,72 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       }
     }
   }
+
+  // array_remove on integer arrays: the element to remove is a value present in the data (2),
+  // a value absent from it (-2), and a NULL literal. Rows come from registerIntArray().
+  test("array_remove - ints") {
+    // Drop the temp view once the body finishes so it does not outlive this test.
+    withTempView("int_array") {
+      registerIntArray()
+      withSQLConf(
+        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
+        CometConf.COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST.key -> "LocalTableScan") {
+        for (query <- Seq(
+            "select a, array_remove(a, 2) from int_array",
+            "select a, array_remove(a, -2) from int_array",
+            "select a, array_remove(a, null) from int_array")) {
+          checkSparkAnswerAndOperator(sql(query), classOf[LocalTableScanExec])
+        }
+      }
+    }
+  }
+
+  // array_remove on string arrays: the element to remove is a value present in the data
+  // ('two'), two values absent from it ('' and 'four'), and a NULL literal.
+  test("array_remove - strings") {
+    // Drop the temp view once the body finishes so it does not outlive this test.
+    withTempView("string_array") {
+      registerStringArray()
+      withSQLConf(
+        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
+        CometConf.COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST.key -> "LocalTableScan") {
+        for (query <- Seq(
+            "select a, array_remove(a, 'two') from string_array",
+            "select a, array_remove(a, '') from string_array",
+            "select a, array_remove(a, 'four') from string_array",
+            "select a, array_remove(a, null) from string_array")) {
+          checkSparkAnswerAndOperator(sql(query), classOf[LocalTableScanExec])
+        }
+      }
+    }
+  }
+
+  // Registers the temp view `int_array` with a single column `a` holding optional arrays of
+  // optional ints. Rows: distinct values, duplicates, a null element, a null array, an empty
+  // array, and an array of only nulls.
+  private def registerIntArray(): Unit = {
+    val values: Seq[Option[Array[Option[Int]]]] = Seq(
+      Some(Array(Some(1), Some(2), Some(3))),
+      Some(Array(Some(1), Some(2), Some(2))),
+      Some(Array(None, Some(2), Some(2))),
+      None,
+      Some(Array()),
+      Some(Array(None, None)))
+    values.toDF("a").createOrReplaceTempView("int_array")
+  }
+
+  // Registers the temp view `string_array` with a single column `a` holding optional arrays of
+  // optional strings. Rows mirror registerIntArray(): distinct values, duplicates, a null
+  // element, a null array, an empty array, and an array of only nulls.
+  private def registerStringArray(): Unit = {
+    val values: Seq[Option[Array[Option[String]]]] = Seq(
+      Some(Array(Some("one"), Some("two"), Some("three"))),
+      Some(Array(Some("one"), Some("two"), Some("two"))),
+      Some(Array(None, Some("two"), Some("two"))),
+      None,
+      Some(Array()),
+      Some(Array(None, None)))
+    values.toDF("a").createOrReplaceTempView("string_array")
+  }
+
 }