diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
index f0083e95fc2c..d6c22c00af7b 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
@@ -38,7 +38,9 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
   test("roundtrip in to_avro and from_avro - int and string") {
     val df = spark.range(10).select('id, 'id.cast("string").as("str"))
-    val avroDF = df.select(to_avro('id).as("a"), to_avro('str).as("b"))
+    val avroDF = df.select(
+      functions.to_avro('id).as("a"),
+      functions.to_avro('str).as("b"))
     val avroTypeLong = s"""
       |{
       |  "type": "int",
       |  "name": "id"
       |}
     """.stripMargin
@@ -51,12 +53,14 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
       |  "name": "str"
       |}
     """.stripMargin
-    checkAnswer(avroDF.select(from_avro('a, avroTypeLong), from_avro('b, avroTypeStr)), df)
+    checkAnswer(avroDF.select(
+      functions.from_avro('a, avroTypeLong),
+      functions.from_avro('b, avroTypeStr)), df)
   }
 
   test("roundtrip in to_avro and from_avro - struct") {
     val df = spark.range(10).select(struct('id, 'id.cast("string").as("str")).as("struct"))
-    val avroStructDF = df.select(to_avro('struct).as("avro"))
+    val avroStructDF = df.select(functions.to_avro('struct).as("avro"))
     val avroTypeStruct = s"""
       |{
       |  "type": "record",
@@ -67,13 +71,14 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
       |  ]
       |}
     """.stripMargin
-    checkAnswer(avroStructDF.select(from_avro('avro, avroTypeStruct)), df)
+    checkAnswer(avroStructDF.select(
+      functions.from_avro('avro, avroTypeStruct)), df)
   }
 
   test("handle invalid input in from_avro") {
     val count = 10
     val df = spark.range(count).select(struct('id, 'id.as("id2")).as("struct"))
-    val avroStructDF = df.select(to_avro('struct).as("avro"))
+    val avroStructDF = df.select(functions.to_avro('struct).as("avro"))
     val avroTypeStruct = s"""
       |{
       |  "type": "record",
@@ -87,7 +92,7 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
 
     intercept[SparkException] {
       avroStructDF.select(
-        org.apache.spark.sql.avro.functions.from_avro(
+        functions.from_avro(
           'avro, avroTypeStruct, Map("mode" -> "FAILFAST").asJava)).collect()
     }
@@ -95,7 +100,7 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
     val expected = (0 until count).map(_ => Row(Row(null, null)))
     checkAnswer(
       avroStructDF.select(
-        org.apache.spark.sql.avro.functions.from_avro(
+        functions.from_avro(
           'avro, avroTypeStruct, Map("mode" -> "PERMISSIVE").asJava)),
       expected)
   }
@@ -115,8 +120,8 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
       | }, "null" ]
       |}, "null" ]
     """.stripMargin
-    val readBackOne = dfOne.select(to_avro($"array").as("avro"))
-      .select(from_avro($"avro", avroTypeArrStruct).as("array"))
+    val readBackOne = dfOne.select(functions.to_avro($"array").as("avro"))
+      .select(functions.from_avro($"avro", avroTypeArrStruct).as("array"))
     checkAnswer(dfOne, readBackOne)
   }
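
Reviewer note: the suite now routes every call through the stable `org.apache.spark.sql.avro.functions` object instead of the deprecated package-level `to_avro`/`from_avro` aliases. A minimal standalone sketch of the non-deprecated usage; the object name `AvroRoundTripExample` and the local-mode session are illustrative, not part of the patch:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.avro.functions.{from_avro, to_avro}

object AvroRoundTripExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("avro-roundtrip").getOrCreate()
    import spark.implicits._

    // Avro schema matching the long column produced by range().
    val avroSchema = """{"type": "long", "name": "id"}"""

    val roundTripped = spark.range(5)
      .select(to_avro($"id").as("avro"))                // encode the column to Avro binary
      .select(from_avro($"avro", avroSchema).as("id"))  // decode it back

    roundTripped.show()
    spark.stop()
  }
}
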
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/DeprecatedAvroFunctionsSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/DeprecatedAvroFunctionsSuite.scala
new file mode 100644
index 000000000000..cdfa1b118b18
--- /dev/null
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/DeprecatedAvroFunctionsSuite.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.avro
+
+import java.io.ByteArrayOutputStream
+
+import org.apache.avro.Schema
+import org.apache.avro.generic.{GenericDatumWriter, GenericRecord, GenericRecordBuilder}
+import org.apache.avro.io.EncoderFactory
+
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.execution.LocalTableScanExec
+import org.apache.spark.sql.functions.{col, struct}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+@deprecated("This test suite will be removed.", "3.0.0")
+class DeprecatedAvroFunctionsSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("roundtrip in to_avro and from_avro - int and string") {
+    val df = spark.range(10).select('id, 'id.cast("string").as("str"))
+
+    val avroDF = df.select(to_avro('id).as("a"), to_avro('str).as("b"))
+    val avroTypeLong = s"""
+      |{
+      |  "type": "int",
+      |  "name": "id"
+      |}
+    """.stripMargin
+    val avroTypeStr = s"""
+      |{
+      |  "type": "string",
+      |  "name": "str"
+      |}
+    """.stripMargin
+    checkAnswer(avroDF.select(from_avro('a, avroTypeLong), from_avro('b, avroTypeStr)), df)
+  }
+
+  test("roundtrip in to_avro and from_avro - struct") {
+    val df = spark.range(10).select(struct('id, 'id.cast("string").as("str")).as("struct"))
+    val avroStructDF = df.select(to_avro('struct).as("avro"))
+    val avroTypeStruct = s"""
+      |{
+      |  "type": "record",
+      |  "name": "struct",
+      |  "fields": [
+      |    {"name": "col1", "type": "long"},
+      |    {"name": "col2", "type": "string"}
+      |  ]
+      |}
+    """.stripMargin
+    checkAnswer(avroStructDF.select(from_avro('avro, avroTypeStruct)), df)
+  }
+
+  test("roundtrip in to_avro and from_avro - array with null") {
+    val dfOne = Seq(Tuple1(Tuple1(1) :: Nil), Tuple1(null :: Nil)).toDF("array")
+    val avroTypeArrStruct = s"""
+      |[ {
+      |  "type" : "array",
+      |  "items" : [ {
+      |    "type" : "record",
+      |    "name" : "x",
+      |    "fields" : [ {
+      |      "name" : "y",
+      |      "type" : "int"
+      |    } ]
+      |  }, "null" ]
+      |}, "null" ]
+    """.stripMargin
+    val readBackOne = dfOne.select(to_avro($"array").as("avro"))
+      .select(from_avro($"avro", avroTypeArrStruct).as("array"))
+    checkAnswer(dfOne, readBackOne)
+  }
+
+  test("SPARK-27798: from_avro produces same value when converted to local relation") {
+    val simpleSchema =
+      """
+        |{
+        |  "type": "record",
+        |  "name" : "Payload",
+        |  "fields" : [ {"name" : "message", "type" : "string" } ]
+        |}
+      """.stripMargin
+
+    def generateBinary(message: String, avroSchema: String): Array[Byte] = {
+      val schema = new Schema.Parser().parse(avroSchema)
+      val out = new ByteArrayOutputStream()
+      val writer = new GenericDatumWriter[GenericRecord](schema)
+      val encoder = EncoderFactory.get().binaryEncoder(out, null)
+      val rootRecord = new GenericRecordBuilder(schema).set("message", message).build()
+      writer.write(rootRecord, encoder)
+      encoder.flush()
+      out.toByteArray
+    }
+
+    // This bug is hit when the rule `ConvertToLocalRelation` is run. But the rule was excluded
+    // in `SharedSparkSession`.
+    withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "") {
+      val df = Seq("one", "two", "three", "four").map(generateBinary(_, simpleSchema))
+        .toDF()
+        .withColumn("value", from_avro(col("value"), simpleSchema))
+
+      assert(df.queryExecution.executedPlan.isInstanceOf[LocalTableScanExec])
+      assert(df.collect().map(_.get(0)) === Seq(Row("one"), Row("two"), Row("three"), Row("four")))
+    }
+  }
+}
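
For context on the `generateBinary` helper above: it writes one GenericRecord with Avro's binary encoder, which is exactly the byte layout `from_avro` decodes. A sketch of the symmetric read path using plain Avro APIs, handy for checking such fixtures outside Spark; the helper name `readBinary` is illustrative:

import org.apache.avro.Schema
import org.apache.avro.generic.{GenericDatumReader, GenericRecord}
import org.apache.avro.io.DecoderFactory

// Decodes a single binary-encoded record, e.g. the output of generateBinary.
def readBinary(bytes: Array[Byte], avroSchema: String): GenericRecord = {
  val schema = new Schema.Parser().parse(avroSchema)
  val reader = new GenericDatumReader[GenericRecord](schema)
  val decoder = DecoderFactory.get().binaryDecoder(bytes, null)
  reader.read(null, decoder)
}
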
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 817387b2845f..6ffe133ee652 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -19,12 +19,10 @@ package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
-import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructType}
 
-
 object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
   override def zero: (Long, Long) = (0, 0)
   override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
@@ -226,25 +224,6 @@ class DatasetAggregatorSuite extends QueryTest with SharedSparkSession {
 
   private implicit val ordering = Ordering.by((c: AggData) => c.a -> c.b)
 
-  test("typed aggregation: TypedAggregator") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-
-    checkDataset(
-      ds.groupByKey(_._1).agg(typed.sum(_._2)),
-      ("a", 30.0), ("b", 3.0), ("c", 1.0))
-  }
-
-  test("typed aggregation: TypedAggregator, expr, expr") {
-    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
-
-    checkDataset(
-      ds.groupByKey(_._1).agg(
-        typed.sum(_._2),
-        expr("sum(_2)").as[Long],
-        count("*")),
-      ("a", 30.0, 30L, 2L), ("b", 3.0, 3L, 2L), ("c", 1.0, 1L, 1L))
-  }
-
   test("typed aggregation: complex result type") {
     val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
 
@@ -255,17 +234,6 @@ class DatasetAggregatorSuite extends QueryTest with SharedSparkSession {
       ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L)))
   }
 
-  test("typed aggregation: in project list") {
-    val ds = Seq(1, 3, 2, 5).toDS()
-
-    checkDataset(
-      ds.select(typed.sum((i: Int) => i)),
-      11.0)
-    checkDataset(
-      ds.select(typed.sum((i: Int) => i), typed.sum((i: Int) => i * 2)),
-      11.0 -> 22.0)
-  }
-
   test("typed aggregation: class input") {
     val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
 
@@ -315,14 +283,6 @@ class DatasetAggregatorSuite extends QueryTest with SharedSparkSession {
       ("one", 1), ("two", 1))
   }
 
-  test("typed aggregate: avg, count, sum") {
-    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
-    checkDataset(
-      ds.groupByKey(_._1).agg(
-        typed.avg(_._2), typed.count(_._2), typed.sum(_._2), typed.sumLong(_._2)),
-      ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L))
-  }
-
   test("generic typed sum") {
     val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
     checkDataset(
@@ -366,18 +326,6 @@ class DatasetAggregatorSuite extends QueryTest with SharedSparkSession {
     checkAnswer(df2.agg(RowAgg.toColumn as "b").select("b"), Row(6) :: Nil)
   }
 
-  test("spark-15114 shorter system generated alias names") {
-    val ds = Seq(1, 3, 2, 5).toDS()
-    assert(ds.select(typed.sum((i: Int) => i)).columns.head === "TypedSumDouble(int)")
-    val ds2 = ds.select(typed.sum((i: Int) => i), typed.avg((i: Int) => i))
-    assert(ds2.columns.head === "TypedSumDouble(int)")
-    assert(ds2.columns.last === "TypedAverage(int)")
-    val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j")
-    assert(df.groupBy($"j").agg(RowAgg.toColumn).columns.last ==
-      "RowAgg(org.apache.spark.sql.Row)")
-    assert(df.groupBy($"j").agg(RowAgg.toColumn as "agg1").columns.last == "agg1")
-  }
-
   test("SPARK-15814 Aggregator can return null result") {
     val ds = Seq(AggData(1, "one"), AggData(2, "two")).toDS()
     checkDatasetUnorderly(
@@ -390,15 +338,6 @@ class DatasetAggregatorSuite extends QueryTest with SharedSparkSession {
     checkDataset(ds.select(MapTypeBufferAgg.toColumn), 1)
   }
 
-  test("SPARK-15204 improve nullability inference for Aggregator") {
-    val ds1 = Seq(1, 3, 2, 5).toDS()
-    assert(ds1.select(typed.sum((i: Int) => i)).schema.head.nullable === false)
-    val ds2 = Seq(AggData(1, "a"), AggData(2, "a")).toDS()
-    assert(ds2.select(SeqAgg.toColumn).schema.head.nullable)
-    val ds3 = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData]
-    assert(ds3.select(NameAgg.toColumn).schema.head.nullable)
-  }
-
   test("SPARK-18147: very complex aggregator result type") {
     val df = Seq(1 -> "a", 2 -> "b", 2 -> "c").toDF("i", "j")
 
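
The `typed.*` helpers removed above are deprecated (the new suite that receives them is itself marked `@deprecated` for 3.0.0); the supported replacement is a custom `org.apache.spark.sql.expressions.Aggregator`, which the tests remaining in this suite (e.g. `ComplexResultAgg`) already use. A minimal sketch of an Aggregator mirroring `typed.sumLong(_._2)`; the object name `SumLongAgg` is illustrative:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator

// Sums the Int component of (String, Int) pairs, like typed.sumLong(_._2).
object SumLongAgg extends Aggregator[(String, Int), Long, Long] {
  override def zero: Long = 0L
  override def reduce(buffer: Long, input: (String, Int)): Long = buffer + input._2
  override def merge(b1: Long, b2: Long): Long = b1 + b2
  override def finish(reduction: Long): Long = reduction
  override def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// Usage on a Dataset[(String, Int)] named ds:
//   ds.groupByKey(_._1).agg(SumLongAgg.toColumn)
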
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
index c80e675b149d..7b6b93549667 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.unsafe.types.CalendarInterval
 
 class DateFunctionsSuite extends QueryTest with SharedSparkSession {
@@ -704,91 +704,6 @@ class DateFunctionsSuite extends QueryTest with SharedSparkSession {
     checkAnswer(df.selectExpr("datediff(a, d)"), Seq(Row(1), Row(1)))
   }
 
-  test("from_utc_timestamp with literal zone") {
-    val df = Seq(
-      (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"),
-      (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00")
-    ).toDF("a", "b")
-    withSQLConf(SQLConf.UTC_TIMESTAMP_FUNC_ENABLED.key -> "true") {
-      checkAnswer(
-        df.select(from_utc_timestamp(col("a"), "PST")),
-        Seq(
-          Row(Timestamp.valueOf("2015-07-23 17:00:00")),
-          Row(Timestamp.valueOf("2015-07-24 17:00:00"))))
-      checkAnswer(
-        df.select(from_utc_timestamp(col("b"), "PST")),
-        Seq(
-          Row(Timestamp.valueOf("2015-07-23 17:00:00")),
-          Row(Timestamp.valueOf("2015-07-24 17:00:00"))))
-    }
-    val msg = intercept[AnalysisException] {
-      df.select(from_utc_timestamp(col("a"), "PST")).collect()
-    }.getMessage
-    assert(msg.contains(SQLConf.UTC_TIMESTAMP_FUNC_ENABLED.key))
-  }
-
-  test("from_utc_timestamp with column zone") {
-    withSQLConf(SQLConf.UTC_TIMESTAMP_FUNC_ENABLED.key -> "true") {
-      val df = Seq(
-        (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "CET"),
-        (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "PST")
-      ).toDF("a", "b", "c")
-      checkAnswer(
-        df.select(from_utc_timestamp(col("a"), col("c"))),
-        Seq(
-          Row(Timestamp.valueOf("2015-07-24 02:00:00")),
-          Row(Timestamp.valueOf("2015-07-24 17:00:00"))))
-      checkAnswer(
-        df.select(from_utc_timestamp(col("b"), col("c"))),
-        Seq(
-          Row(Timestamp.valueOf("2015-07-24 02:00:00")),
-          Row(Timestamp.valueOf("2015-07-24 17:00:00"))))
-    }
-  }
-
-  test("to_utc_timestamp with literal zone") {
-    val df = Seq(
-      (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"),
-      (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00")
-    ).toDF("a", "b")
-    withSQLConf(SQLConf.UTC_TIMESTAMP_FUNC_ENABLED.key -> "true") {
-      checkAnswer(
-        df.select(to_utc_timestamp(col("a"), "PST")),
-        Seq(
-          Row(Timestamp.valueOf("2015-07-24 07:00:00")),
-          Row(Timestamp.valueOf("2015-07-25 07:00:00"))))
-      checkAnswer(
-        df.select(to_utc_timestamp(col("b"), "PST")),
-        Seq(
-          Row(Timestamp.valueOf("2015-07-24 07:00:00")),
-          Row(Timestamp.valueOf("2015-07-25 07:00:00"))))
-    }
-    val msg = intercept[AnalysisException] {
-      df.select(to_utc_timestamp(col("a"), "PST")).collect()
-    }.getMessage
-    assert(msg.contains(SQLConf.UTC_TIMESTAMP_FUNC_ENABLED.key))
-  }
-
-  test("to_utc_timestamp with column zone") {
-    withSQLConf(SQLConf.UTC_TIMESTAMP_FUNC_ENABLED.key -> "true") {
-      val df = Seq(
-        (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "PST"),
-        (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "CET")
-      ).toDF("a", "b", "c")
-      checkAnswer(
-        df.select(to_utc_timestamp(col("a"), col("c"))),
-        Seq(
-          Row(Timestamp.valueOf("2015-07-24 07:00:00")),
-          Row(Timestamp.valueOf("2015-07-24 22:00:00"))))
-      checkAnswer(
-        df.select(to_utc_timestamp(col("b"), col("c"))),
-        Seq(
-          Row(Timestamp.valueOf("2015-07-24 07:00:00")),
-          Row(Timestamp.valueOf("2015-07-24 22:00:00"))))
-    }
-  }
-
-
   test("to_timestamp with microseconds precision") {
     withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
       val timestamp = "1970-01-01T00:00:00.123456Z"
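
These tests move to DeprecatedDateFunctionsSuite below because `from_utc_timestamp`/`to_utc_timestamp` are deprecated and gated behind `SQLConf.UTC_TIMESTAMP_FUNC_ENABLED`, as the `withSQLConf` blocks show. The semantics being pinned down: the input is taken as a UTC instant and re-rendered in the target zone, so 2015-07-24 00:00:00 becomes 2015-07-23 17:00:00 under "PST" (UTC-7 in July). A sketch of the same gate outside the test helpers, assuming an existing SparkSession named `spark` (note SQLConf is an internal API, referenced here only because the tests use it):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, from_utc_timestamp}
import org.apache.spark.sql.internal.SQLConf

val spark = SparkSession.builder().master("local[*]").getOrCreate()
spark.conf.set(SQLConf.UTC_TIMESTAMP_FUNC_ENABLED.key, "true")
import spark.implicits._

val df = Seq(java.sql.Timestamp.valueOf("2015-07-24 00:00:00")).toDF("ts")
// Renders the UTC instant in PST: 2015-07-23 17:00:00.
df.select(from_utc_timestamp(col("ts"), "PST")).show()
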
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedDatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedDatasetAggregatorSuite.scala
new file mode 100644
index 000000000000..b1d5e80f8563
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedDatasetAggregatorSuite.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.expressions.scalalang.typed
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSparkSession
+
+@deprecated("This test suite will be removed.", "3.0.0")
+class DeprecatedDatasetAggregatorSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("typed aggregation: TypedAggregator") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkDataset(
+      ds.groupByKey(_._1).agg(typed.sum(_._2)),
+      ("a", 30.0), ("b", 3.0), ("c", 1.0))
+  }
+
+  test("typed aggregation: TypedAggregator, expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkDataset(
+      ds.groupByKey(_._1).agg(
+        typed.sum(_._2),
+        expr("sum(_2)").as[Long],
+        count("*")),
+      ("a", 30.0, 30L, 2L), ("b", 3.0, 3L, 2L), ("c", 1.0, 1L, 1L))
+  }
+
+  test("typed aggregation: in project list") {
+    val ds = Seq(1, 3, 2, 5).toDS()
+
+    checkDataset(
+      ds.select(typed.sum((i: Int) => i)),
+      11.0)
+    checkDataset(
+      ds.select(typed.sum((i: Int) => i), typed.sum((i: Int) => i * 2)),
+      11.0 -> 22.0)
+  }
+
+  test("typed aggregate: avg, count, sum") {
+    val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS()
+    checkDataset(
+      ds.groupByKey(_._1).agg(
+        typed.avg(_._2), typed.count(_._2), typed.sum(_._2), typed.sumLong(_._2)),
+      ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L))
+  }
+
+  test("spark-15114 shorter system generated alias names") {
+    val ds = Seq(1, 3, 2, 5).toDS()
+    assert(ds.select(typed.sum((i: Int) => i)).columns.head === "TypedSumDouble(int)")
+    val ds2 = ds.select(typed.sum((i: Int) => i), typed.avg((i: Int) => i))
+    assert(ds2.columns.head === "TypedSumDouble(int)")
+    assert(ds2.columns.last === "TypedAverage(int)")
+    val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j")
+    assert(df.groupBy($"j").agg(RowAgg.toColumn).columns.last ==
+      "RowAgg(org.apache.spark.sql.Row)")
+    assert(df.groupBy($"j").agg(RowAgg.toColumn as "agg1").columns.last == "agg1")
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedDateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedDateFunctionsSuite.scala
new file mode 100644
index 000000000000..bef83ee199cc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedDateFunctionsSuite.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.sql.Timestamp
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+
+@deprecated("This test suite will be removed.", "3.0.0")
+class DeprecatedDateFunctionsSuite extends QueryTest with SharedSparkSession {
+  import testImplicits._
+
+  test("from_utc_timestamp with literal zone") {
+    val df = Seq(
+      (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"),
+      (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00")
+    ).toDF("a", "b")
+    withSQLConf(SQLConf.UTC_TIMESTAMP_FUNC_ENABLED.key -> "true") {
+      checkAnswer(
+        df.select(from_utc_timestamp(col("a"), "PST")),
+        Seq(
+          Row(Timestamp.valueOf("2015-07-23 17:00:00")),
+          Row(Timestamp.valueOf("2015-07-24 17:00:00"))))
+      checkAnswer(
+        df.select(from_utc_timestamp(col("b"), "PST")),
+        Seq(
+          Row(Timestamp.valueOf("2015-07-23 17:00:00")),
+          Row(Timestamp.valueOf("2015-07-24 17:00:00"))))
+    }
+    val msg = intercept[AnalysisException] {
+      df.select(from_utc_timestamp(col("a"), "PST")).collect()
+    }.getMessage
+    assert(msg.contains(SQLConf.UTC_TIMESTAMP_FUNC_ENABLED.key))
+  }
+
+  test("from_utc_timestamp with column zone") {
+    withSQLConf(SQLConf.UTC_TIMESTAMP_FUNC_ENABLED.key -> "true") {
+      val df = Seq(
+        (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "CET"),
+        (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "PST")
+      ).toDF("a", "b", "c")
+      checkAnswer(
+        df.select(from_utc_timestamp(col("a"), col("c"))),
+        Seq(
+          Row(Timestamp.valueOf("2015-07-24 02:00:00")),
+          Row(Timestamp.valueOf("2015-07-24 17:00:00"))))
+      checkAnswer(
+        df.select(from_utc_timestamp(col("b"), col("c"))),
+        Seq(
+          Row(Timestamp.valueOf("2015-07-24 02:00:00")),
+          Row(Timestamp.valueOf("2015-07-24 17:00:00"))))
+    }
+  }
+
+  test("to_utc_timestamp with literal zone") {
+    val df = Seq(
+      (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00"),
+      (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00")
+    ).toDF("a", "b")
+    withSQLConf(SQLConf.UTC_TIMESTAMP_FUNC_ENABLED.key -> "true") {
+      checkAnswer(
+        df.select(to_utc_timestamp(col("a"), "PST")),
+        Seq(
+          Row(Timestamp.valueOf("2015-07-24 07:00:00")),
+          Row(Timestamp.valueOf("2015-07-25 07:00:00"))))
+      checkAnswer(
+        df.select(to_utc_timestamp(col("b"), "PST")),
+        Seq(
+          Row(Timestamp.valueOf("2015-07-24 07:00:00")),
+          Row(Timestamp.valueOf("2015-07-25 07:00:00"))))
+    }
+    val msg = intercept[AnalysisException] {
+      df.select(to_utc_timestamp(col("a"), "PST")).collect()
+    }.getMessage
+    assert(msg.contains(SQLConf.UTC_TIMESTAMP_FUNC_ENABLED.key))
+  }
+
+  test("to_utc_timestamp with column zone") {
+    withSQLConf(SQLConf.UTC_TIMESTAMP_FUNC_ENABLED.key -> "true") {
+      val df = Seq(
+        (Timestamp.valueOf("2015-07-24 00:00:00"), "2015-07-24 00:00:00", "PST"),
+        (Timestamp.valueOf("2015-07-25 00:00:00"), "2015-07-25 00:00:00", "CET")
+      ).toDF("a", "b", "c")
+      checkAnswer(
+        df.select(to_utc_timestamp(col("a"), col("c"))),
+        Seq(
+          Row(Timestamp.valueOf("2015-07-24 07:00:00")),
+          Row(Timestamp.valueOf("2015-07-24 22:00:00"))))
+      checkAnswer(
+        df.select(to_utc_timestamp(col("b"), col("c"))),
+        Seq(
+          Row(Timestamp.valueOf("2015-07-24 07:00:00")),
+          Row(Timestamp.valueOf("2015-07-24 22:00:00"))))
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index ac7976090ef8..068ea05ead35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -27,7 +27,7 @@ import org.mockito.Mockito._
 import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, GenericRow, SortOrder}
 import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, SortExec, SparkPlan}
 import org.apache.spark.sql.execution.joins._
@@ -238,7 +238,9 @@ class JoinSuite extends QueryTest with SharedSparkSession {
 
     checkAnswer(
       bigDataX.join(bigDataY).where($"x.key" === $"y.key"),
-      testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
+      testData.rdd.flatMap { row =>
+        Seq.fill(16)(new GenericRow(Seq(row, row).flatMap(_.toSeq).toArray))
+      }.collect().toSeq)
   }
 
   test("cartesian product join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 37d98f7c8742..06309bfef7e9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -26,6 +26,7 @@ import scala.collection.parallel.immutable.ParVector
 
 import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
+import org.apache.spark.sql.catalyst.expressions.GenericRow
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.util.StringUtils
 import org.apache.spark.sql.execution.HiveResult.hiveResultString
@@ -783,8 +784,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession {
         |  SELECT * FROM testData UNION ALL
         |  SELECT * FROM testData) y
         |WHERE x.key = y.key""".stripMargin),
-      testData.rdd.flatMap(
-        row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq)
+      testData.rdd.flatMap { row =>
+        Seq.fill(16)(new GenericRow(Seq(row, row).flatMap(_.toSeq).toArray))
+      }.collect().toSeq)
   }
 
   test("cartesian product join") {
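
Both hunks above swap the deprecated `Row.merge` for a direct `GenericRow` construction; the two spellings produce the same flattened row. A minimal standalone sketch of the equivalence, outside the test harness:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRow

val r = Row(1, "a")
// Deprecated spelling: Row.merge(r, r)
// The patch's replacement: concatenate the field sequences by hand.
val merged: Row = new GenericRow(Seq(r, r).flatMap(_.toSeq).toArray)
assert(merged == Row(1, "a", 1, "a"))
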
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala
new file mode 100644
index 000000000000..c198978f5888
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DeprecatedWholeStageCodegenSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.execution.aggregate.HashAggregateExec
+import org.apache.spark.sql.expressions.scalalang.typed
+import org.apache.spark.sql.test.SharedSparkSession
+
+@deprecated("This test suite will be removed.", "3.0.0")
+class DeprecatedWholeStageCodegenSuite extends QueryTest with SharedSparkSession {
+
+  test("simple typed UDAF should be included in WholeStageCodegen") {
+    import testImplicits._
+
+    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
+      .groupByKey(_._1).agg(typed.sum(_._2))
+
+    val plan = ds.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegenExec] &&
+        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
+    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index afe9eb5c151d..572932fc2750 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.execution.joins.SortMergeJoinExec
-import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -107,19 +106,6 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession {
     assert(ds.collect() === Array(0, 6))
   }
 
-  test("simple typed UDAF should be included in WholeStageCodegen") {
-    import testImplicits._
-
-    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
-      .groupByKey(_._1).agg(typed.sum(_._2))
-
-    val plan = ds.queryExecution.executedPlan
-    assert(plan.find(p =>
-      p.isInstanceOf[WholeStageCodegenExec] &&
-        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
-    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
-  }
-
   test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") {
     import testImplicits._
 
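
The moved test asserts codegen coverage by searching the executed plan for a `WholeStageCodegenExec` wrapping a `HashAggregateExec`. When such an assertion fails, the generated code can be inspected directly; a short sketch using Spark's debug helpers, assuming `ds` is the Dataset from a test like the one above:

import org.apache.spark.sql.execution.debug._

ds.explain()       // shows which operators carry the whole-stage codegen marker
ds.debugCodegen()  // dumps the generated Java source per codegen subtree
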
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeprecatedStreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeprecatedStreamingAggregationSuite.scala
new file mode 100644
index 000000000000..99f7e32d4df7
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeprecatedStreamingAggregationSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.scalatest.Assertions
+
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.state.StreamingAggregationStateManager
+import org.apache.spark.sql.expressions.scalalang.typed
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.OutputMode._
+
+@deprecated("This test suite will be removed.", "3.0.0")
+class DeprecatedStreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
+
+  import testImplicits._
+
+  def executeFuncWithStateVersionSQLConf(
+      stateVersion: Int,
+      confPairs: Seq[(String, String)],
+      func: => Any): Unit = {
+    withSQLConf(confPairs ++
+      Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString): _*) {
+      func
+    }
+  }
+
+  def testWithAllStateVersions(name: String, confPairs: (String, String)*)
+      (func: => Any): Unit = {
+    for (version <- StreamingAggregationStateManager.supportedVersions) {
+      test(s"$name - state format version $version") {
+        executeFuncWithStateVersionSQLConf(version, confPairs, func)
+      }
+    }
+  }
+
+
+  testWithAllStateVersions("typed aggregators") {
+    val inputData = MemoryStream[(String, Int)]
+    val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2))
+
+    testStream(aggregated, Update)(
+      AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)),
+      CheckLastBatch(("a", 30), ("b", 3), ("c", 1))
+    )
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 6dbf4ff283af..9779635df3ef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -34,7 +34,6 @@ import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.MemorySink
 import org.apache.spark.sql.execution.streaming.state.StreamingAggregationStateManager
-import org.apache.spark.sql.expressions.scalalang.typed
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode._
@@ -280,16 +279,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
     )
   }
 
-  testWithAllStateVersions("typed aggregators") {
-    val inputData = MemoryStream[(String, Int)]
-    val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2))
-
-    testStream(aggregated, Update)(
-      AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)),
-      CheckLastBatch(("a", 30), ("b", 3), ("c", 1))
-    )
-  }
-
   testWithAllStateVersions("prune results by current_time, complete mode") {
     import testImplicits._
     val clock = new StreamManualClock