diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 3228a2bd8cda..76ec85120e86 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -40,23 +40,27 @@ import org.apache.spark.sql.types.StringType * * To register Scala UDF in SQL: * {{{ - * registerTestUDF(TestScalaUDF(name = "udf_name"), spark) + * val scalaTestUDF = TestScalaUDF(name = "udf_name") + * registerTestUDF(scalaTestUDF, spark) * }}} * * To register Python UDF in SQL: * {{{ - * registerTestUDF(TestPythonUDF(name = "udf_name"), spark) + * val pythonTestUDF = TestPythonUDF(name = "udf_name") + * registerTestUDF(pythonTestUDF, spark) * }}} * * To register Scalar Pandas UDF in SQL: * {{{ - * registerTestUDF(TestScalarPandasUDF(name = "udf_name"), spark) + * val pandasTestUDF = TestScalarPandasUDF(name = "udf_name") + * registerTestUDF(pandasTestUDF, spark) * }}} * * To use it in Scala API and SQL: * {{{ * sql("SELECT udf_name(1)") - * spark.select(expr("udf_name(1)") + * spark.range(10).select(expr("udf_name(id)") + * spark.range(10).select(pandasTestUDF($"id")) * }}} */ object IntegratedUDFTestUtils extends SQLHelper { @@ -64,8 +68,9 @@ object IntegratedUDFTestUtils extends SQLHelper { private lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "") private lazy val sparkHome = if (sys.props.contains(Tests.IS_TESTING.key)) { - assert(sys.props.contains("spark.test.home"), "spark.test.home is not set.") - sys.props("spark.test.home") + assert(sys.props.contains("spark.test.home") || + sys.env.contains("SPARK_HOME"), "spark.test.home or SPARK_HOME is not set.") + sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME")) } else { assert(sys.env.contains("SPARK_HOME"), "SPARK_HOME is not set.") sys.env("SPARK_HOME") @@ -186,14 +191,18 @@ object IntegratedUDFTestUtils extends SQLHelper { /** * A base trait for various UDFs defined in this object. */ - sealed trait TestUDF + sealed trait TestUDF { + def apply(exprs: Column*): Column + + val prettyName: String + } /** * A Python UDF that takes one column and returns a string column. * Equivalent to `udf(lambda x: str(x), "string")` */ case class TestPythonUDF(name: String) extends TestUDF { - lazy val udf = UserDefinedPythonFunction( + private[IntegratedUDFTestUtils] lazy val udf = UserDefinedPythonFunction( name = name, func = PythonFunction( command = pythonFunc, @@ -206,6 +215,10 @@ object IntegratedUDFTestUtils extends SQLHelper { dataType = StringType, pythonEvalType = PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) + + def apply(exprs: Column*): Column = udf(exprs: _*) + + val prettyName: String = "Regular Python UDF" } /** @@ -213,7 +226,7 @@ object IntegratedUDFTestUtils extends SQLHelper { * Equivalent to `pandas_udf(lambda x: x.apply(str), "string", PandasUDFType.SCALAR)`. */ case class TestScalarPandasUDF(name: String) extends TestUDF { - lazy val udf = UserDefinedPythonFunction( + private[IntegratedUDFTestUtils] lazy val udf = UserDefinedPythonFunction( name = name, func = PythonFunction( command = pandasFunc, @@ -226,6 +239,10 @@ object IntegratedUDFTestUtils extends SQLHelper { dataType = StringType, pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF, udfDeterministic = true) + + def apply(exprs: Column*): Column = udf(exprs: _*) + + val prettyName: String = "Scalar Pandas UDF" } /** @@ -233,10 +250,14 @@ object IntegratedUDFTestUtils extends SQLHelper { * Equivalent to `udf((input: Any) => input.toString)`. */ case class TestScalaUDF(name: String) extends TestUDF { - lazy val udf = SparkUserDefinedFunction( + private[IntegratedUDFTestUtils] lazy val udf = SparkUserDefinedFunction( (input: Any) => input.toString, StringType, inputSchemas = Seq.fill(1)(None)) + + def apply(exprs: Column*): Column = udf(exprs: _*) + + val prettyName: String = "Scala UDF" } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 40232d4213fc..1c8cf6403c6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -383,24 +383,13 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { val testCaseName = absPath.stripPrefix(inputFilePath).stripPrefix(File.separator) if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udf")) { - Seq( + Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), TestScalarPandasUDF("udf")).map { udf => UDFTestCase( - s"$testCaseName - Scala UDF", + s"$testCaseName - ${udf.prettyName}", absPath, resultFile, - TestScalaUDF(name = "udf")), - - UDFTestCase( - s"$testCaseName - Python UDF", - absPath, - resultFile, - TestPythonUDF(name = "udf")), - - UDFTestCase( - s"$testCaseName - Scalar Pandas UDF", - absPath, - resultFile, - TestScalarPandasUDF(name = "udf"))) + udf) + } } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}pgSQL")) { PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil } else {