diff --git a/core/src/test/scala/org/apache/spark/metrics/sink/PrometheusServletSuite.scala b/core/src/test/scala/org/apache/spark/metrics/sink/PrometheusServletSuite.scala
index 080ca0e41f79..260a4fc50b3e 100644
--- a/core/src/test/scala/org/apache/spark/metrics/sink/PrometheusServletSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/sink/PrometheusServletSuite.scala
@@ -63,7 +63,7 @@ class PrometheusServletSuite extends SparkFunSuite with PrivateMethodTester {
     val key = "local-1592132938718.driver.LiveListenerBus." +
       "listenerProcessingTime.org.apache.spark.HeartbeatReceiver"
     val sink = createPrometheusServlet()
-    val suffix = sink invokePrivate PrivateMethod[String]('normalizeKey)(key)
+    val suffix = sink invokePrivate PrivateMethod[String](Symbol("normalizeKey"))(key)
     assert(suffix == "metrics_local_1592132938718_driver_LiveListenerBus_" +
       "listenerProcessingTime_org_apache_spark_HeartbeatReceiver_")
   }
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala
index 5510f0019353..d5a0e6c998eb 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/SimpleTypedAggregator.scala
@@ -31,7 +31,7 @@ object SimpleTypedAggregator {
       .getOrCreate()

     import spark.implicits._
-    val ds = spark.range(20).select(('id % 3).as("key"), 'id).as[(Long, Long)]
+    val ds = spark.range(20).select((Symbol("id") % 3).as("key"), Symbol("id")).as[(Long, Long)]
     println("input data:")
     ds.show()

diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
index c9e0d4344691..5cf88602c681 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala
@@ -36,11 +36,11 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
   import testImplicits._

   test("roundtrip in to_avro and from_avro - int and string") {
-    val df = spark.range(10).select('id, 'id.cast("string").as("str"))
+    val df = spark.range(10).select(Symbol("id"), Symbol("id").cast("string").as("str"))

     val avroDF = df.select(
-      functions.to_avro('id).as("a"),
-      functions.to_avro('str).as("b"))
+      functions.to_avro(Symbol("id")).as("a"),
+      functions.to_avro(Symbol("str")).as("b"))
     val avroTypeLong = s"""
       |{
       |  "type": "int",
@@ -54,13 +54,14 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
       |  "name": "str"
       |}
     """.stripMargin
     checkAnswer(avroDF.select(
-      functions.from_avro('a, avroTypeLong),
-      functions.from_avro('b, avroTypeStr)), df)
+      functions.from_avro(Symbol("a"), avroTypeLong),
+      functions.from_avro(Symbol("b"), avroTypeStr)), df)
   }

   test("roundtrip in to_avro and from_avro - struct") {
-    val df = spark.range(10).select(struct('id, 'id.cast("string").as("str")).as("struct"))
-    val avroStructDF = df.select(functions.to_avro('struct).as("avro"))
+    val df = spark.range(10).select(
+      struct(Symbol("id"), Symbol("id").cast("string").as("str")).as("struct"))
+    val avroStructDF = df.select(functions.to_avro(Symbol("struct")).as("avro"))
     val avroTypeStruct = s"""
       |{
       |  "type": "record",
@@ -72,13 +73,13 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
       |}
     """.stripMargin
     checkAnswer(avroStructDF.select(
-      functions.from_avro('avro, avroTypeStruct)), df)
+      functions.from_avro(Symbol("avro"), avroTypeStruct)), df)
   }

   test("handle invalid input in from_avro") {
     val count = 10
-    val df = spark.range(count).select(struct('id, 'id.as("id2")).as("struct"))
-    val avroStructDF = df.select(functions.to_avro('struct).as("avro"))
+    val df = spark.range(count).select(struct(Symbol("id"), Symbol("id").as("id2")).as("struct"))
+    val avroStructDF = df.select(functions.to_avro(Symbol("struct")).as("avro"))
     val avroTypeStruct = s"""
       |{
       |  "type": "record",
@@ -93,7 +94,7 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
     intercept[SparkException] {
       avroStructDF.select(
         functions.from_avro(
-          'avro, avroTypeStruct, Map("mode" -> "FAILFAST").asJava)).collect()
+          Symbol("avro"), avroTypeStruct, Map("mode" -> "FAILFAST").asJava)).collect()
     }

     // For PERMISSIVE mode, the result should be row of null columns.
@@ -101,7 +102,7 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
     checkAnswer(
       avroStructDF.select(
         functions.from_avro(
-          'avro, avroTypeStruct, Map("mode" -> "PERMISSIVE").asJava)),
+          Symbol("avro"), avroTypeStruct, Map("mode" -> "PERMISSIVE").asJava)),
       expected)
   }

@@ -161,9 +162,9 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {

   test("SPARK-27506: roundtrip in to_avro and from_avro with different compatible schemas") {
     val df = spark.range(10).select(
-      struct('id.as("col1"), 'id.cast("string").as("col2")).as("struct")
+      struct(Symbol("id").as("col1"), Symbol("id").cast("string").as("col2")).as("struct")
     )
-    val avroStructDF = df.select(functions.to_avro('struct).as("avro"))
+    val avroStructDF = df.select(functions.to_avro(Symbol("struct")).as("avro"))

     val actualAvroSchema = s"""
       |{
@@ -190,20 +191,24 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
       |""".stripMargin

     val expected = spark.range(10).select(
-      struct('id.as("col1"), 'id.cast("string").as("col2"), lit("").as("col3")).as("struct")
+      struct(
+        Symbol("id").as("col1"),
+        Symbol("id").cast("string").as("col2"),
+        lit("").as("col3")).as("struct")
     )

     checkAnswer(
       avroStructDF.select(
         functions.from_avro(
-          'avro,
+          Symbol("avro"),
           actualAvroSchema,
           Map("avroSchema" -> evolvedAvroSchema).asJava)),
       expected)
   }

   test("roundtrip in to_avro and from_avro - struct with nullable Avro schema") {
-    val df = spark.range(10).select(struct('id, 'id.cast("string").as("str")).as("struct"))
+    val df = spark.range(10).select(
+      struct(Symbol("id"), Symbol("id").cast("string").as("str")).as("struct"))
     val avroTypeStruct = s"""
       |{
       |  "type": "record",
@@ -214,13 +219,14 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
       | ]
       |}
     """.stripMargin
-    val avroStructDF = df.select(functions.to_avro('struct, avroTypeStruct).as("avro"))
+    val avroStructDF = df.select(functions.to_avro(Symbol("struct"), avroTypeStruct).as("avro"))
     checkAnswer(avroStructDF.select(
-      functions.from_avro('avro, avroTypeStruct)), df)
+      functions.from_avro(Symbol("avro"), avroTypeStruct)), df)
   }

   test("to_avro with unsupported nullable Avro schema") {
-    val df = spark.range(10).select(struct('id, 'id.cast("string").as("str")).as("struct"))
+    val df = spark.range(10).select(
+      struct(Symbol("id"), Symbol("id").cast("string").as("str")).as("struct"))
     for (unsupportedAvroType <- Seq("""["null", "int", "long"]""", """["int", "long"]""")) {
       val avroTypeStruct = s"""
         |{
@@ -233,7 +239,7 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
         |}
       """.stripMargin
       val message = intercept[SparkException] {
-        df.select(functions.to_avro('struct, avroTypeStruct).as("avro")).show()
+        df.select(functions.to_avro(Symbol("struct"), avroTypeStruct).as("avro")).show()
       }.getCause.getMessage
       assert(message.contains("Only UNION of a null type and a non-null type is supported"))
     }
diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
index 12ebddf72b03..fc3a18197160 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala
@@ -122,7 +122,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
     withTempDir { dir =>
       val expected = timestampInputData.map(t => Row(new Timestamp(t._1)))
       val timestampAvro = timestampFile(dir.getAbsolutePath)
-      val df = spark.read.format("avro").load(timestampAvro).select('timestamp_millis)
+      val df = spark.read.format("avro").load(timestampAvro).select(Symbol("timestamp_millis"))

       checkAnswer(df, expected)

@@ -137,7 +137,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
     withTempDir { dir =>
       val expected = timestampInputData.map(t => Row(new Timestamp(t._2)))
       val timestampAvro = timestampFile(dir.getAbsolutePath)
-      val df = spark.read.format("avro").load(timestampAvro).select('timestamp_micros)
+      val df = spark.read.format("avro").load(timestampAvro).select(Symbol("timestamp_micros"))

       checkAnswer(df, expected)

@@ -151,8 +151,8 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
   test("Logical type: user specified output schema with different timestamp types") {
     withTempDir { dir =>
       val timestampAvro = timestampFile(dir.getAbsolutePath)
-      val df =
-        spark.read.format("avro").load(timestampAvro).select('timestamp_millis, 'timestamp_micros)
+      val df = spark.read.format("avro").load(timestampAvro).select(
+        Symbol("timestamp_millis"), Symbol("timestamp_micros"))

       val expected = timestampInputData.map(t =>
         Row(new Timestamp(t._1), new Timestamp(t._2)))
@@ -184,7 +184,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession {
     withTempDir { dir =>
       val timestampAvro = timestampFile(dir.getAbsolutePath)
       val schema = StructType(StructField("long", TimestampType, true) :: Nil)
-      val df = spark.read.format("avro").schema(schema).load(timestampAvro).select('long)
+      val df = spark.read.format("avro").schema(schema).load(timestampAvro).select(Symbol("long"))

       val expected = timestampInputData.map(t => Row(new Timestamp(t._3)))

diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/DeprecatedAvroFunctionsSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/DeprecatedAvroFunctionsSuite.scala
index cdfa1b118b18..4eb94afdfc09 100644
--- a/external/avro/src/test/scala/org/apache/spark/sql/avro/DeprecatedAvroFunctionsSuite.scala
+++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/DeprecatedAvroFunctionsSuite.scala
@@ -34,9 +34,9 @@ class DeprecatedAvroFunctionsSuite extends QueryTest with SharedSparkSession {
   import testImplicits._

   test("roundtrip in to_avro and from_avro - int and string") {
-    val df = spark.range(10).select('id, 'id.cast("string").as("str"))
+    val df = spark.range(10).select(Symbol("id"), Symbol("id").cast("string").as("str"))

-    val avroDF = df.select(to_avro('id).as("a"), to_avro('str).as("b"))
+    val avroDF = df.select(to_avro(Symbol("id")).as("a"), to_avro(Symbol("str")).as("b"))
val avroTypeLong = s""" |{ | "type": "int", @@ -49,12 +49,14 @@ class DeprecatedAvroFunctionsSuite extends QueryTest with SharedSparkSession { | "name": "str" |} """.stripMargin - checkAnswer(avroDF.select(from_avro('a, avroTypeLong), from_avro('b, avroTypeStr)), df) + checkAnswer( + avroDF.select(from_avro(Symbol("a"), avroTypeLong), from_avro(Symbol("b"), avroTypeStr)), df) } test("roundtrip in to_avro and from_avro - struct") { - val df = spark.range(10).select(struct('id, 'id.cast("string").as("str")).as("struct")) - val avroStructDF = df.select(to_avro('struct).as("avro")) + val df = spark.range(10).select( + struct(Symbol("id"), Symbol("id").cast("string").as("str")).as("struct")) + val avroStructDF = df.select(to_avro(Symbol("struct")).as("avro")) val avroTypeStruct = s""" |{ | "type": "record", @@ -65,7 +67,7 @@ class DeprecatedAvroFunctionsSuite extends QueryTest with SharedSparkSession { | ] |} """.stripMargin - checkAnswer(avroStructDF.select(from_avro('avro, avroTypeStruct)), df) + checkAnswer(avroStructDF.select(from_avro(Symbol("avro"), avroTypeStruct)), df) } test("roundtrip in to_avro and from_avro - array with null") { diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 62ba459070c2..9e8ee1e9a76e 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -554,9 +554,9 @@ abstract class KafkaMicroBatchSourceSuiteBase extends KafkaSourceSuiteBase { val windowedAggregation = kafka .withWatermark("timestamp", "10 seconds") - .groupBy(window($"timestamp", "5 seconds") as 'window) - .agg(count("*") as 'count) - .select($"window".getField("start") as 'window, $"count") + .groupBy(window($"timestamp", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) + .select($"window".getField("start") as Symbol("window"), $"count") val query = windowedAggregation .writeStream diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 4e808a5277a9..f54eff90a5e0 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -524,7 +524,7 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase { test("SPARK-20496: batch - enforce analyzed plans") { val inputEvents = spark.range(1, 1000) - .select(to_json(struct("*")) as 'value) + .select(to_json(struct("*")) as Symbol("value")) val topic = newTopic() testUtils.createTopic(topic) diff --git a/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala index 21b823383d23..7a9ea30793a1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/FunctionsSuite.scala @@ -34,7 +34,7 @@ class FunctionsSuite extends MLTest { (Vectors.sparse(3, Seq((0, 2.0), (2, 3.0))), OldVectors.sparse(3, Seq((0, 20.0), (2, 30.0)))) ).toDF("vec", "oldVec") - val result = df.select(vector_to_array('vec), vector_to_array('oldVec)) + val result = df.select(vector_to_array(Symbol("vec")), 
vector_to_array(Symbol("oldVec"))) .as[(Seq[Double], Seq[Double])].collect().toSeq val expected = Seq( @@ -65,7 +65,8 @@ class FunctionsSuite extends MLTest { (Vectors.sparse(3, Seq((0, 2.0), (2, 3.0))), OldVectors.sparse(3, Seq((0, 20.0), (2, 30.0)))) ).toDF("vec", "oldVec") val dfArrayFloat = df3.select( - vector_to_array('vec, dtype = "float32"), vector_to_array('oldVec, dtype = "float32")) + vector_to_array(Symbol("vec"), dtype = "float32"), + vector_to_array(Symbol("oldVec"), dtype = "float32")) // Check values are correct val result3 = dfArrayFloat.as[(Seq[Float], Seq[Float])].collect().toSeq @@ -82,7 +83,8 @@ class FunctionsSuite extends MLTest { val thrown2 = intercept[IllegalArgumentException] { df3.select( - vector_to_array('vec, dtype = "float16"), vector_to_array('oldVec, dtype = "float16")) + vector_to_array(Symbol("vec"), dtype = "float16"), + vector_to_array(Symbol("oldVec"), dtype = "float16")) } assert(thrown2.getMessage.contains( s"Unsupported dtype: float16. Valid values: float64, float32.")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 0e0142eb7689..5af9337b569c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -52,7 +52,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { test("fail for unresolved plan") { intercept[AnalysisException] { // `testRelation` does not have column `b`. - testRelation.select('b).analyze + testRelation.select(Symbol("b")).analyze } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala index 72e10eadf79f..6aee9ea3b8dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/PullOutNondeterministicSuite.scala @@ -27,10 +27,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation */ class PullOutNondeterministicSuite extends AnalysisTest { - private lazy val a = 'a.int - private lazy val b = 'b.int + private lazy val a = Symbol("a").int + private lazy val b = Symbol("b").int private lazy val r = LocalRelation(a, b) - private lazy val rnd = Rand(10).as('_nondeterministic) + private lazy val rnd = Rand(10).as(Symbol("_nondeterministic")) private lazy val rndref = rnd.toAttribute test("no-op on filter") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala index cdfae1413829..b31233c3117e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupingAnalyticsSuite.scala @@ -27,14 +27,14 @@ import org.apache.spark.sql.types._ class ResolveGroupingAnalyticsSuite extends AnalysisTest { - lazy val a = 'a.int - lazy val b = 'b.string - lazy val c = 'c.string + lazy val a = Symbol("a").int + lazy val b = Symbol("b").string + lazy val c = Symbol("c").string lazy val unresolved_a = UnresolvedAttribute("a") lazy val unresolved_b = 
UnresolvedAttribute("b") lazy val unresolved_c = UnresolvedAttribute("c") - lazy val gid = 'spark_grouping_id.long.withNullability(false) - lazy val hive_gid = 'grouping__id.long.withNullability(false) + lazy val gid = Symbol("spark_grouping_id").long.withNullability(false) + lazy val hive_gid = Symbol("grouping__id").long.withNullability(false) lazy val grouping_a = Cast(ShiftRight(gid, 1) & 1L, ByteType, Option(TimeZone.getDefault().getID)) lazy val nulInt = Literal(null, IntegerType) lazy val nulStr = Literal(null, StringType) @@ -287,7 +287,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b))) val expected = Project(Seq(a, b), Sort( - Seq(SortOrder('aggOrder.byte.withNullability(false), Ascending)), true, + Seq(SortOrder(Symbol("aggOrder").byte.withNullability(false), Ascending)), true, Aggregate(Seq(a, b, gid), Seq(a, b, grouping_a.as("aggOrder")), Expand( @@ -308,7 +308,7 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest { GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)), Seq(unresolved_a, unresolved_b), r1, Seq(unresolved_a, unresolved_b))) val expected3 = Project(Seq(a, b), Sort( - Seq(SortOrder('aggOrder.long.withNullability(false), Ascending)), true, + Seq(SortOrder(Symbol("aggOrder").long.withNullability(false), Ascending)), true, Aggregate(Seq(a, b, gid), Seq(a, b, gid.as("aggOrder")), Expand( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala index 513f1d001f75..cd1b51b61e50 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala @@ -71,9 +71,9 @@ class ResolveHintsSuite extends AnalysisTest { test("do not traverse past existing broadcast hints") { checkAnalysis( UnresolvedHint("MAPJOIN", Seq("table"), - ResolvedHint(table("table").where('a > 1), HintInfo(strategy = Some(BROADCAST)))), - ResolvedHint(testRelation.where('a > 1), HintInfo(strategy = Some(BROADCAST))).analyze, - caseSensitive = false) + ResolvedHint(table("table").where(Symbol("a") > 1), HintInfo(strategy = Some(BROADCAST)))), + ResolvedHint(testRelation.where( + Symbol("a") > 1), HintInfo(strategy = Some(BROADCAST))).analyze, caseSensitive = false) } test("should work for subqueries") { @@ -83,7 +83,7 @@ class ResolveHintsSuite extends AnalysisTest { caseSensitive = false) checkAnalysis( - UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)), + UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery(Symbol("tableAlias"))), ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))), caseSensitive = false) @@ -96,8 +96,10 @@ class ResolveHintsSuite extends AnalysisTest { test("do not traverse past subquery alias") { checkAnalysis( - UnresolvedHint("MAPJOIN", Seq("table"), table("table").where('a > 1).subquery('tableAlias)), - testRelation.where('a > 1).analyze, + UnresolvedHint( + "MAPJOIN", Seq("table"), + table("table").where(Symbol("a") > 1).subquery(Symbol("tableAlias"))), + testRelation.where(Symbol("a") > 1).analyze, caseSensitive = false) } @@ -109,8 +111,9 @@ class ResolveHintsSuite extends AnalysisTest { |SELECT /*+ BROADCAST(ctetable) */ * FROM ctetable """.stripMargin ), - 
ResolvedHint(testRelation.where('a > 1).select('a), HintInfo(strategy = Some(BROADCAST))) - .select('a).analyze, + ResolvedHint( + testRelation.where(Symbol("a") > 1).select(Symbol("a")), + HintInfo(strategy = Some(BROADCAST))).select(Symbol("a")).analyze, caseSensitive = false) } @@ -122,7 +125,7 @@ class ResolveHintsSuite extends AnalysisTest { |SELECT /*+ BROADCAST(table) */ * FROM ctetable """.stripMargin ), - testRelation.where('a > 1).select('a).select('a).analyze, + testRelation.where(Symbol("a") > 1).select(Symbol("a")).select(Symbol("a")).analyze, caseSensitive = false) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala index b9233a27f3d7..e1991d4491a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala @@ -35,9 +35,9 @@ class ResolveLambdaVariablesSuite extends PlanTest { val batches = Batch("Resolution", FixedPoint(4), ResolveLambdaVariables) :: Nil } - private val key = 'key.int - private val values1 = 'values1.array(IntegerType) - private val values2 = 'values2.array(ArrayType(ArrayType(IntegerType))) + private val key = Symbol("key").int + private val values1 = Symbol("values1").array(IntegerType) + private val values2 = Symbol("values2").array(ArrayType(ArrayType(IntegerType))) private val data = LocalRelation(Seq(key, values1, values2)) private val lvInt = NamedLambdaVariable("x", IntegerType, nullable = true) private val lvHiddenInt = NamedLambdaVariable("col0", IntegerType, nullable = true) @@ -56,14 +56,19 @@ class ResolveLambdaVariablesSuite extends PlanTest { } test("resolution - simple") { - val in = ArrayTransform(values1, LambdaFunction(lv('x) + 1, lv('x) :: Nil)) + val in = ArrayTransform(values1, LambdaFunction(lv(Symbol("x")) + 1, lv(Symbol("x")) :: Nil)) val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil)) checkExpression(in, out) } test("resolution - nested") { - val in = ArrayTransform(values2, LambdaFunction( - ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, lv('x) :: Nil)), lv('x) :: Nil)) + val in = ArrayTransform( + values2, + LambdaFunction( + ArrayTransform( + lv(Symbol("x")), + LambdaFunction(lv(Symbol("x")) + 1, lv(Symbol("x")) :: Nil)), + lv(Symbol("x")) :: Nil)) val out = ArrayTransform(values2, LambdaFunction( ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil)) checkExpression(in, out) @@ -77,14 +82,16 @@ class ResolveLambdaVariablesSuite extends PlanTest { test("fail - name collisions") { val p = plan(ArrayTransform(values1, - LambdaFunction(lv('x) + lv('X), lv('x) :: lv('X) :: Nil))) + LambdaFunction(lv(Symbol("x")) + lv(Symbol("X")), lv(Symbol("x")) :: lv(Symbol("X")) :: Nil))) val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage assert(msg.contains("arguments should not have names that are semantically the same")) } test("fail - lambda arguments") { val p = plan(ArrayTransform(values1, - LambdaFunction(lv('x) + lv('y) + lv('z), lv('x) :: lv('y) :: lv('z) :: Nil))) + LambdaFunction( + lv(Symbol("x")) + lv(Symbol("y")) + lv(Symbol("z")), + lv(Symbol("x")) :: lv(Symbol("y")) :: lv(Symbol("z")) :: Nil))) val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage assert(msg.contains("does not match the number of arguments expected")) } 
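
Every hunk in this patch applies the same mechanical rewrite: the single-quote symbol literal ('id), deprecated since Scala 2.13, becomes an explicit Symbol("id") call. Because Symbol instances are interned, the two spellings denote the same object, so the change is purely syntactic; a minimal standalone sketch of the equivalence (object and value names chosen only for illustration):

```scala
object SymbolLiteralMigration extends App {
  // Old spelling, removed throughout this patch (deprecated on Scala 2.13):
  //   val column = 'id
  // New spelling used instead:
  val column = Symbol("id")

  // Symbol.apply interns its instances, so repeated calls return the same object,
  // which is why the rewrite does not change behaviour.
  assert(Symbol("id") eq column)
  assert(column.name == "id")
  println(s"symbol name = ${column.name}")
}
```
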
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala index ea2284e5420b..27d37ade7efe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala @@ -24,10 +24,10 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation class ResolveNaturalJoinSuite extends AnalysisTest { - lazy val a = 'a.string - lazy val b = 'b.string - lazy val c = 'c.string - lazy val d = 'd.struct('f1.int, 'f2.long) + lazy val a = Symbol("a").string + lazy val b = Symbol("b").string + lazy val c = Symbol("c").string + lazy val d = Symbol("d").struct(Symbol("f1").int, Symbol("f2").long) lazy val aNotNull = a.notNull lazy val bNotNull = b.notNull lazy val cNotNull = c.notNull diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala index 5aa80e1a9bd7..63e3cd88a980 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.catalyst.plans.logical._ */ class ResolveSubquerySuite extends AnalysisTest { - val a = 'a.int - val b = 'b.int - val c = 'c.int + val a = Symbol("a").int + val b = Symbol("b").int + val c = Symbol("c").int val t1 = LocalRelation(a) val t2 = LocalRelation(b) val t3 = LocalRelation(c) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala index 5ddfa9f2191e..9b5f8c7c3b63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolvedUuidExpressionsSuite.scala @@ -28,11 +28,11 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} */ class ResolvedUuidExpressionsSuite extends AnalysisTest { - private lazy val a = 'a.int + private lazy val a = Symbol("a").int private lazy val r = LocalRelation(a) - private lazy val uuid1 = Uuid().as('_uuid1) - private lazy val uuid2 = Uuid().as('_uuid2) - private lazy val uuid3 = Uuid().as('_uuid3) + private lazy val uuid1 = Uuid().as(Symbol("_uuid1")) + private lazy val uuid2 = Uuid().as(Symbol("_uuid2")) + private lazy val uuid3 = Uuid().as(Symbol("_uuid3")) private lazy val uuid1Ref = uuid1.toAttribute private val tracker = new QueryPlanningTracker diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala index c0312282c76c..5bfd268bd9dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinalsSuite.scala @@ -52,10 +52,10 @@ class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { test("group by ordinal") { // Tests group 
by ordinal, apply single rule. - val plan2 = testRelation2.groupBy(Literal(1), Literal(2))('a, 'b) + val plan2 = testRelation2.groupBy(Literal(1), Literal(2))(Symbol("a"), Symbol("b")) comparePlans( SubstituteUnresolvedOrdinals.apply(plan2), - testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))('a, 'b)) + testRelation2.groupBy(UnresolvedOrdinal(1), UnresolvedOrdinal(2))(Symbol("a"), Symbol("b"))) // Tests group by ordinal, do full analysis checkAnalysis(plan2, testRelation2.groupBy(a, b)(a, b)) @@ -64,7 +64,7 @@ class SubstituteUnresolvedOrdinalsSuite extends AnalysisTest { withSQLConf(SQLConf.GROUP_BY_ORDINAL.key -> "false") { comparePlans( SubstituteUnresolvedOrdinals.apply(plan2), - testRelation2.groupBy(Literal(1), Literal(2))('a, 'b)) + testRelation2.groupBy(Literal(1), Literal(2))(Symbol("a"), Symbol("b"))) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index d310538e302d..69796c973888 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -498,18 +498,18 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac val tbl2 = catalog.getTable("db2", "tbl2") checkAnswer(tbl2, Seq.empty, Set(part1, part2)) - checkAnswer(tbl2, Seq('a.int <= 1), Set(part1)) - checkAnswer(tbl2, Seq('a.int === 2), Set.empty) - checkAnswer(tbl2, Seq(In('a.int * 10, Seq(30))), Set(part2)) - checkAnswer(tbl2, Seq(Not(In('a.int, Seq(4)))), Set(part1, part2)) - checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "2"), Set(part1)) - checkAnswer(tbl2, Seq('a.int === 1 && 'b.string === "2"), Set(part1)) - checkAnswer(tbl2, Seq('a.int === 1, 'b.string === "x"), Set.empty) - checkAnswer(tbl2, Seq('a.int === 1 || 'b.string === "x"), Set(part1)) + checkAnswer(tbl2, Seq(Symbol("a").int <= 1), Set(part1)) + checkAnswer(tbl2, Seq(Symbol("a").int === 2), Set.empty) + checkAnswer(tbl2, Seq(In(Symbol("a").int * 10, Seq(30))), Set(part2)) + checkAnswer(tbl2, Seq(Not(In(Symbol("a").int, Seq(4)))), Set(part1, part2)) + checkAnswer(tbl2, Seq(Symbol("a").int === 1, Symbol("b").string === "2"), Set(part1)) + checkAnswer(tbl2, Seq(Symbol("a").int === 1 && Symbol("b").string === "2"), Set(part1)) + checkAnswer(tbl2, Seq(Symbol("a").int === 1, Symbol("b").string === "x"), Set.empty) + checkAnswer(tbl2, Seq(Symbol("a").int === 1 || Symbol("b").string === "x"), Set(part1)) intercept[AnalysisException] { try { - checkAnswer(tbl2, Seq('a.int > 0 && 'col1.int > 0), Set.empty) + checkAnswer(tbl2, Seq(Symbol("a").int > 0 && Symbol("col1").int > 0), Set.empty) } catch { // HiveExternalCatalog may be the first one to notice and throw an exception, which will // then be caught and converted to a RuntimeException with a descriptive message. 
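
The catalyst suites above also build typed attributes with expressions such as Symbol("a").int and Symbol("b").string. Those methods are not defined on scala.Symbol itself; they come from the catalyst test DSL's implicit enrichment of Symbol, which is why they keep working unchanged after the rewrite. A rough, self-contained sketch of that kind of enrichment, using toy types rather than Spark's actual DSL classes:

```scala
// Toy stand-ins for Spark's attribute and data-type classes (illustration only).
sealed trait ToyType
case object ToyInt extends ToyType
case object ToyString extends ToyType

final case class ToyAttribute(name: String, dataType: ToyType)

object ToyDsl {
  // Conceptually what the test DSL does: enrich Symbol so that
  // Symbol("a").int builds a typed attribute named "a".
  implicit class RichSymbol(val s: Symbol) extends AnyVal {
    def int: ToyAttribute = ToyAttribute(s.name, ToyInt)
    def string: ToyAttribute = ToyAttribute(s.name, ToyString)
  }
}

object ToyDslDemo extends App {
  import ToyDsl._
  val a = Symbol("a").int      // same shape as in the suites above
  val b = Symbol("b").string
  println(Seq(a, b))
}
```
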
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 577814b9c669..898517bb07a0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -54,17 +54,17 @@ class EncoderResolutionSuite extends PlanTest { val encoder = ExpressionEncoder[StringLongClass] // int type can be up cast to long type - val attrs1 = Seq('a.string, 'b.int) + val attrs1 = Seq(Symbol("a").string, Symbol("b").int) testFromRow(encoder, attrs1, InternalRow(str, 1)) // int type can be up cast to string type - val attrs2 = Seq('a.int, 'b.long) + val attrs2 = Seq(Symbol("a").int, Symbol("b").long) testFromRow(encoder, attrs2, InternalRow(1, 2L)) } test("real type doesn't match encoder schema but they are compatible: nested product") { val encoder = ExpressionEncoder[ComplexClass] - val attrs = Seq('a.int, 'b.struct('a.int, 'b.long)) + val attrs = Seq(Symbol("a").int, Symbol("b").struct(Symbol("a").int, Symbol("b").long)) testFromRow(encoder, attrs, InternalRow(1, InternalRow(2, 3L))) } @@ -72,20 +72,20 @@ class EncoderResolutionSuite extends PlanTest { val encoder = ExpressionEncoder.tuple( ExpressionEncoder[StringLongClass], ExpressionEncoder[Long]) - val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int) + val attrs = Seq(Symbol("a").struct(Symbol("a").string, Symbol("b").byte), Symbol("b").int) testFromRow(encoder, attrs, InternalRow(InternalRow(str, 1.toByte), 2)) } test("real type doesn't match encoder schema but they are compatible: primitive array") { val encoder = ExpressionEncoder[PrimitiveArrayClass] - val attrs = Seq('arr.array(IntegerType)) + val attrs = Seq(Symbol("arr").array(IntegerType)) val array = new GenericArrayData(Array(1, 2, 3)) testFromRow(encoder, attrs, InternalRow(array)) } test("the real type is not compatible with encoder schema: primitive array") { val encoder = ExpressionEncoder[PrimitiveArrayClass] - val attrs = Seq('arr.array(StringType)) + val attrs = Seq(Symbol("arr").array(StringType)) assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == s""" |Cannot up cast array element from string to bigint. 
@@ -99,7 +99,8 @@ class EncoderResolutionSuite extends PlanTest { test("real type doesn't match encoder schema but they are compatible: array") { val encoder = ExpressionEncoder[ArrayClass] - val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int"))) + val attrs = + Seq(Symbol("arr").array(new StructType().add("a", "int").add("b", "int").add("c", "int"))) val array = new GenericArrayData(Array(InternalRow(1, 2, 3))) testFromRow(encoder, attrs, InternalRow(array)) } @@ -108,7 +109,7 @@ class EncoderResolutionSuite extends PlanTest { val encoder = ExpressionEncoder[NestedArrayClass] val et = new StructType().add("arr", ArrayType( new StructType().add("a", "int").add("b", "int").add("c", "int"))) - val attrs = Seq('nestedArr.array(et)) + val attrs = Seq(Symbol("nestedArr").array(et)) val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3))) val outerArr = new GenericArrayData(Array(InternalRow(innerArr))) testFromRow(encoder, attrs, InternalRow(outerArr)) @@ -116,14 +117,14 @@ class EncoderResolutionSuite extends PlanTest { test("the real type is not compatible with encoder schema: non-array field") { val encoder = ExpressionEncoder[ArrayClass] - val attrs = Seq('arr.int) + val attrs = Seq(Symbol("arr").int) assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "need an array field but got int") } test("the real type is not compatible with encoder schema: array element type") { val encoder = ExpressionEncoder[ArrayClass] - val attrs = Seq('arr.array(new StructType().add("c", "int"))) + val attrs = Seq(Symbol("arr").array(new StructType().add("c", "int"))) assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "No such struct field a in c") } @@ -132,13 +133,13 @@ class EncoderResolutionSuite extends PlanTest { val encoder = ExpressionEncoder[NestedArrayClass] withClue("inner element is not array") { - val attrs = Seq('nestedArr.array(new StructType().add("arr", "int"))) + val attrs = Seq(Symbol("nestedArr").array(new StructType().add("arr", "int"))) assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "need an array field but got int") } withClue("nested array element type is not compatible") { - val attrs = Seq('nestedArr.array(new StructType() + val attrs = Seq(Symbol("nestedArr").array(new StructType() .add("arr", ArrayType(new StructType().add("c", "int"))))) assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "No such struct field a in c") @@ -147,7 +148,7 @@ class EncoderResolutionSuite extends PlanTest { test("nullability of array type element should not fail analysis") { val encoder = ExpressionEncoder[Seq[Int]] - val attrs = 'a.array(IntegerType) :: Nil + val attrs = Symbol("a").array(IntegerType) :: Nil // It should pass analysis val fromRow = encoder.resolveAndBind(attrs).createDeserializer() @@ -166,14 +167,14 @@ class EncoderResolutionSuite extends PlanTest { val encoder = ExpressionEncoder[(String, Long)] { - val attrs = Seq('a.string, 'b.long, 'c.int) + val attrs = Seq(Symbol("a").string, Symbol("b").long, Symbol("c").int) assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct to Tuple2, " + "but failed as the number of fields does not line up.") } { - val attrs = Seq('a.string) + val attrs = Seq(Symbol("a").string) assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct to Tuple2, " + "but failed as the number of fields does not line up.") @@ -184,14 +185,16 
@@ class EncoderResolutionSuite extends PlanTest { val encoder = ExpressionEncoder[(String, (Long, String))] { - val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int)) + val attrs = Seq( + Symbol("a").string, + Symbol("b").struct(Symbol("x").long, Symbol("y").string, Symbol("z").int)) assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct to Tuple2, " + "but failed as the number of fields does not line up.") } { - val attrs = Seq('a.string, 'b.struct('x.long)) + val attrs = Seq(Symbol("a").string, Symbol("b").struct(Symbol("x").long)) assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == "Try to map struct to Tuple2, " + "but failed as the number of fields does not line up.") @@ -200,14 +203,19 @@ class EncoderResolutionSuite extends PlanTest { test("nested case class can have different number of fields from the real schema") { val encoder = ExpressionEncoder[(String, StringIntClass)] - val attrs = Seq('a.string, 'b.struct('a.string, 'b.int, 'c.int)) + val attrs = Seq( + Symbol("a").string, + Symbol("b").struct(Symbol("a").string, Symbol("b").int, Symbol("c").int)) encoder.resolveAndBind(attrs) } test("SPARK-28497: complex type is not compatible with string encoder schema") { val encoder = ExpressionEncoder[String] - Seq('a.struct('x.long), 'a.array(StringType), 'a.map(StringType, StringType)).foreach { attr => + Seq( + Symbol("a").struct(Symbol("x").long), + Symbol("a").array(StringType), + Symbol("a").map(StringType, StringType)).foreach { attr => val attrs = Seq(attr) assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message == s""" @@ -221,7 +229,7 @@ class EncoderResolutionSuite extends PlanTest { test("throw exception if real type is not compatible with encoder schema") { val msg1 = intercept[AnalysisException] { - ExpressionEncoder[StringIntClass].resolveAndBind(Seq('a.string, 'b.long)) + ExpressionEncoder[StringIntClass].resolveAndBind(Seq(Symbol("a").string, Symbol("b").long)) }.message assert(msg1 == s""" @@ -234,7 +242,8 @@ class EncoderResolutionSuite extends PlanTest { val msg2 = intercept[AnalysisException] { val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT) - ExpressionEncoder[ComplexClass].resolveAndBind(Seq('a.long, 'b.struct(structType))) + ExpressionEncoder[ComplexClass].resolveAndBind( + Seq(Symbol("a").long, Symbol("b").struct(structType))) }.message assert(msg2 == s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 14dd04afebe2..b2bf0cd67c1e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -291,11 +291,11 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("function least") { val row = create_row(1, 2, "a", "b", "c") - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.string.at(2) - val c4 = 'a.string.at(3) - val c5 = 'a.string.at(4) + val c1 = Symbol("a").int.at(0) + val c2 = Symbol("a").int.at(1) + val c3 = Symbol("a").string.at(2) + val c4 = Symbol("a").string.at(3) + val c5 = Symbol("a").string.at(4) checkEvaluation(Least(Seq(c4, c3, c5)), "a", row) checkEvaluation(Least(Seq(c1, c2)), 1, row) checkEvaluation(Least(Seq(c1, c2, 
Literal(-1))), -1, row) @@ -348,11 +348,11 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("function greatest") { val row = create_row(1, 2, "a", "b", "c") - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.string.at(2) - val c4 = 'a.string.at(3) - val c5 = 'a.string.at(4) + val c1 = Symbol("a").int.at(0) + val c2 = Symbol("a").int.at(1) + val c3 = Symbol("a").string.at(2) + val c4 = Symbol("a").string.at(3) + val c5 = Symbol("a").string.at(4) checkEvaluation(Greatest(Seq(c4, c5, c3)), "c", row) checkEvaluation(Greatest(Seq(c2, c1)), 2, row) checkEvaluation(Greatest(Seq(c1, c2, Literal(2))), 2, row) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala index 718d8dd44321..8384c7893dbf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/BitwiseExpressionsSuite.scala @@ -151,11 +151,11 @@ class BitwiseExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row4 = create_row(11.toShort, 16) val row5 = create_row(11.toByte, 16) - val tl = 't.long.at(0) - val ti = 't.int.at(0) - val ts = 't.short.at(0) - val tb = 't.byte.at(0) - val p = 'p.int.at(1) + val tl = Symbol("t").long.at(0) + val ti = Symbol("t").int.at(0) + val ts = Symbol("t").short.at(0) + val tb = Symbol("t").byte.at(0) + val p = Symbol("p").int.at(1) val expr = BitwiseGet(tl, p) checkExceptionInExpression[IllegalArgumentException]( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala index ac31a68b2b61..fabee508d07e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala @@ -99,9 +99,9 @@ class CanonicalizeSuite extends SparkFunSuite { test("SPARK-32927: Bitwise operations are commutative") { Seq(BitwiseOr(_, _), BitwiseAnd(_, _), BitwiseXor(_, _)).foreach { f => - val e1 = f('a, f('b, 'c)) - val e2 = f(f('a, 'b), 'c) - val e3 = f('a, f('b, 'a)) + val e1 = f(Symbol("a"), f(Symbol("b"), Symbol("c"))) + val e2 = f(f(Symbol("a"), Symbol("b")), Symbol("c")) + val e3 = f(Symbol("a"), f(Symbol("b"), Symbol("a"))) assert(e1.canonicalized == e2.canonicalized) assert(e1.canonicalized != e3.canonicalized) @@ -110,9 +110,9 @@ class CanonicalizeSuite extends SparkFunSuite { test("SPARK-32927: Bitwise operations are commutative for non-deterministic expressions") { Seq(BitwiseOr(_, _), BitwiseAnd(_, _), BitwiseXor(_, _)).foreach { f => - val e1 = f('a, f(rand(42), 'c)) - val e2 = f(f('a, rand(42)), 'c) - val e3 = f('a, f(rand(42), 'a)) + val e1 = f(Symbol("a"), f(rand(42), Symbol("c"))) + val e2 = f(f(Symbol("a"), rand(42)), Symbol("c")) + val e3 = f(Symbol("a"), f(rand(42), Symbol("a"))) assert(e1.canonicalized == e2.canonicalized) assert(e1.canonicalized != e3.canonicalized) @@ -121,9 +121,9 @@ class CanonicalizeSuite extends SparkFunSuite { test("SPARK-32927: Bitwise operations are commutative for literal expressions") { Seq(BitwiseOr(_, _), BitwiseAnd(_, _), BitwiseXor(_, _)).foreach { f => - val e1 = f('a, f(42, 'c)) - val e2 = f(f('a, 42), 'c) - val e3 = f('a, f(42, 'a)) + val e1 = 
f(Symbol("a"), f(42, Symbol("c"))) + val e2 = f(f(Symbol("a"), 42), Symbol("c")) + val e3 = f(Symbol("a"), f(42, Symbol("a"))) assert(e1.canonicalized == e2.canonicalized) assert(e1.canonicalized != e3.canonicalized) @@ -133,9 +133,9 @@ class CanonicalizeSuite extends SparkFunSuite { test("SPARK-32927: Bitwise operations are commutative in a complex case") { Seq(BitwiseOr(_, _), BitwiseAnd(_, _), BitwiseXor(_, _)).foreach { f1 => Seq(BitwiseOr(_, _), BitwiseAnd(_, _), BitwiseXor(_, _)).foreach { f2 => - val e1 = f2(f1('a, f1('b, 'c)), 'a) - val e2 = f2(f1(f1('a, 'b), 'c), 'a) - val e3 = f2(f1('a, f1('b, 'a)), 'a) + val e1 = f2(f1(Symbol("a"), f1(Symbol("b"), Symbol("c"))), Symbol("a")) + val e2 = f2(f1(f1(Symbol("a"), Symbol("b")), Symbol("c")), Symbol("a")) + val e3 = f2(f1(Symbol("a"), f1(Symbol("b"), Symbol("a"))), Symbol("a")) assert(e1.canonicalized == e2.canonicalized) assert(e1.canonicalized != e3.canonicalized) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 57abdb4de229..4a7f8c806ad4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -386,16 +386,16 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { test("CreateStruct") { val row = create_row(1, 2, 3) - val c1 = 'a.int.at(0) - val c3 = 'c.int.at(2) + val c1 = Symbol("a").int.at(0) + val c3 = Symbol("c").int.at(2) checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row) checkEvaluation(CreateStruct(Literal.create(null, LongType) :: Nil), create_row(null)) } test("CreateNamedStruct") { val row = create_row(1, 2, 3) - val c1 = 'a.int.at(0) - val c3 = 'c.int.at(2) + val c1 = Symbol("a").int.at(0) + val c3 = Symbol("c").int.at(2) checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", c3)), create_row(1, 3), row) checkEvaluation(CreateNamedStruct(Seq("a", c1, "b", "y")), create_row(1, UTF8String.fromString("y")), row) @@ -410,11 +410,12 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { ExtractValue(u.child, u.extraction, _ == _) } - checkEvaluation(quickResolve('c.map(MapType(StringType, StringType)).at(0).getItem("a")), + checkEvaluation( + quickResolve(Symbol("c").map(MapType(StringType, StringType)).at(0).getItem("a")), "b", create_row(Map("a" -> "b"))) - checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)), + checkEvaluation(quickResolve(Symbol("c").array(StringType).at(0).getItem(1)), "b", create_row(Seq("a", "b"))) - checkEvaluation(quickResolve('c.struct('a.int).at(0).getField("a")), + checkEvaluation(quickResolve(Symbol("c").struct(Symbol("a").int).at(0).getField("a")), 1, create_row(create_row(1))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index ee6f89a155ae..a4703dd98e6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -72,12 +72,12 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("case when") { val row = create_row(null, false, true, "a", "b", "c") - 
val c1 = 'a.boolean.at(0) - val c2 = 'a.boolean.at(1) - val c3 = 'a.boolean.at(2) - val c4 = 'a.string.at(3) - val c5 = 'a.string.at(4) - val c6 = 'a.string.at(5) + val c1 = Symbol("a").boolean.at(0) + val c2 = Symbol("a").boolean.at(1) + val c3 = Symbol("a").boolean.at(2) + val c4 = Symbol("a").string.at(3) + val c5 = Symbol("a").string.at(4) + val c6 = Symbol("a").string.at(5) checkEvaluation(CaseWhen(Seq((c1, c4)), c6), "c", row) checkEvaluation(CaseWhen(Seq((c2, c4)), c6), "c", row) @@ -95,9 +95,9 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper assert(CaseWhen(Seq((c2, c4), (c3, c5)), c6).nullable) assert(CaseWhen(Seq((c2, c4), (c3, c5))).nullable) - val c4_notNull = 'a.boolean.notNull.at(3) - val c5_notNull = 'a.boolean.notNull.at(4) - val c6_notNull = 'a.boolean.notNull.at(5) + val c4_notNull = Symbol("a").boolean.notNull.at(3) + val c5_notNull = Symbol("a").boolean.notNull.at(4) + val c6_notNull = Symbol("a").boolean.notNull.at(5) assert(CaseWhen(Seq((c2, c4_notNull)), c6_notNull).nullable === false) assert(CaseWhen(Seq((c2, c4)), c6_notNull).nullable) @@ -186,12 +186,12 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper test("case key when") { val row = create_row(null, 1, 2, "a", "b", "c") - val c1 = 'a.int.at(0) - val c2 = 'a.int.at(1) - val c3 = 'a.int.at(2) - val c4 = 'a.string.at(3) - val c5 = 'a.string.at(4) - val c6 = 'a.string.at(5) + val c1 = Symbol("a").int.at(0) + val c2 = Symbol("a").int.at(1) + val c3 = Symbol("a").int.at(2) + val c4 = Symbol("a").string.at(3) + val c5 = Symbol("a").string.at(4) + val c6 = Symbol("a").string.at(5) val literalNull = Literal.create(null, IntegerType) val literalInt = Literal(1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala index c12dd3051d27..03711274837d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSQLBuilderSuite.scala @@ -94,72 +94,74 @@ class ExpressionSQLBuilderSuite extends SparkFunSuite { } test("attributes") { - checkSQL('a.int, "`a`") + checkSQL(Symbol("a").int, "`a`") checkSQL(Symbol("foo bar").int, "`foo bar`") // Keyword - checkSQL('int.int, "`int`") + checkSQL(Symbol("int").int, "`int`") } test("binary comparisons") { - checkSQL('a.int === 'b.int, "(`a` = `b`)") - checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)") - checkSQL('a.int =!= 'b.int, "(NOT (`a` = `b`))") + checkSQL(Symbol("a").int === Symbol("b").int, "(`a` = `b`)") + checkSQL(Symbol("a").int <=> Symbol("b").int, "(`a` <=> `b`)") + checkSQL(Symbol("a").int =!= Symbol("b").int, "(NOT (`a` = `b`))") - checkSQL('a.int < 'b.int, "(`a` < `b`)") - checkSQL('a.int <= 'b.int, "(`a` <= `b`)") - checkSQL('a.int > 'b.int, "(`a` > `b`)") - checkSQL('a.int >= 'b.int, "(`a` >= `b`)") + checkSQL(Symbol("a").int < Symbol("b").int, "(`a` < `b`)") + checkSQL(Symbol("a").int <= Symbol("b").int, "(`a` <= `b`)") + checkSQL(Symbol("a").int > Symbol("b").int, "(`a` > `b`)") + checkSQL(Symbol("a").int >= Symbol("b").int, "(`a` >= `b`)") - checkSQL('a.int in ('b.int, 'c.int), "(`a` IN (`b`, `c`))") - checkSQL('a.int in (1, 2), "(`a` IN (1, 2))") + checkSQL(Symbol("a").int in (Symbol("b").int, Symbol("c").int), "(`a` IN (`b`, `c`))") + checkSQL(Symbol("a").int in (1, 2), "(`a` IN (1, 2))") - 
checkSQL('a.int.isNull, "(`a` IS NULL)") - checkSQL('a.int.isNotNull, "(`a` IS NOT NULL)") + checkSQL(Symbol("a").int.isNull, "(`a` IS NULL)") + checkSQL(Symbol("a").int.isNotNull, "(`a` IS NOT NULL)") } test("logical operators") { - checkSQL('a.boolean && 'b.boolean, "(`a` AND `b`)") - checkSQL('a.boolean || 'b.boolean, "(`a` OR `b`)") - checkSQL(!'a.boolean, "(NOT `a`)") - checkSQL(If('a.boolean, 'b.int, 'c.int), "(IF(`a`, `b`, `c`))") + checkSQL(Symbol("a").boolean && Symbol("b").boolean, "(`a` AND `b`)") + checkSQL(Symbol("a").boolean || Symbol("b").boolean, "(`a` OR `b`)") + checkSQL(!Symbol("a").boolean, "(NOT `a`)") + checkSQL(If(Symbol("a").boolean, Symbol("b").int, Symbol("c").int), "(IF(`a`, `b`, `c`))") } test("arithmetic expressions") { - checkSQL('a.int + 'b.int, "(`a` + `b`)") - checkSQL('a.int - 'b.int, "(`a` - `b`)") - checkSQL('a.int * 'b.int, "(`a` * `b`)") - checkSQL('a.int / 'b.int, "(`a` / `b`)") - checkSQL('a.int % 'b.int, "(`a` % `b`)") - - checkSQL(-'a.int, "(- `a`)") - checkSQL(-('a.int + 'b.int), "(- (`a` + `b`))") + checkSQL(Symbol("a").int + Symbol("b").int, "(`a` + `b`)") + checkSQL(Symbol("a").int - Symbol("b").int, "(`a` - `b`)") + checkSQL(Symbol("a").int * Symbol("b").int, "(`a` * `b`)") + checkSQL(Symbol("a").int / Symbol("b").int, "(`a` / `b`)") + checkSQL(Symbol("a").int % Symbol("b").int, "(`a` % `b`)") + + checkSQL(-Symbol("a").int, "(- `a`)") + checkSQL(-(Symbol("a").int + Symbol("b").int), "(- (`a` + `b`))") } test("window specification") { val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow) checkSQL( - WindowSpecDefinition('a.int :: Nil, Nil, frame), + WindowSpecDefinition(Symbol("a").int :: Nil, Nil, frame), s"(PARTITION BY `a` ${frame.sql})" ) checkSQL( - WindowSpecDefinition('a.int :: 'b.string :: Nil, Nil, frame), + WindowSpecDefinition(Symbol("a").int :: Symbol("b").string :: Nil, Nil, frame), s"(PARTITION BY `a`, `b` ${frame.sql})" ) checkSQL( - WindowSpecDefinition(Nil, 'a.int.asc :: Nil, frame), + WindowSpecDefinition(Nil, Symbol("a").int.asc :: Nil, frame), s"(ORDER BY `a` ASC NULLS FIRST ${frame.sql})" ) checkSQL( - WindowSpecDefinition(Nil, 'a.int.asc :: 'b.string.desc :: Nil, frame), + WindowSpecDefinition(Nil, Symbol("a").int.asc :: Symbol("b").string.desc :: Nil, frame), s"(ORDER BY `a` ASC NULLS FIRST, `b` DESC NULLS LAST ${frame.sql})" ) checkSQL( - WindowSpecDefinition('a.int :: 'b.string :: Nil, 'c.int.asc :: 'd.string.desc :: Nil, frame), + WindowSpecDefinition( + Symbol("a").int :: Symbol("b").string :: Nil, + Symbol("c").int.asc :: Symbol("d").string.desc :: Nil, frame), s"(PARTITION BY `a`, `b` ORDER BY `c` ASC NULLS FIRST, `d` DESC NULLS LAST ${frame.sql})" ) } @@ -168,17 +170,17 @@ class ExpressionSQLBuilderSuite extends SparkFunSuite { val interval = Literal(new CalendarInterval(0, 0, MICROS_PER_HOUR)) checkSQL( - TimeAdd('a, interval), + TimeAdd(Symbol("a"), interval), "`a` + INTERVAL '1 hours'" ) checkSQL( - DatetimeSub('a, interval, Literal.default(TimestampType)), + DatetimeSub(Symbol("a"), interval, Literal.default(TimestampType)), "`a` - INTERVAL '1 hours'" ) checkSQL( - DateAddInterval('a, interval), + DateAddInterval(Symbol("a"), interval), "`a` + INTERVAL '1 hours'" ) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 4c4df9ef83de..8f215e451150 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -410,11 +410,11 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { testUnary(Bin, java.lang.Long.toBinaryString, (-20 to 20).map(_.toLong), evalType = LongType) val row = create_row(null, 12L, 123L, 1234L, -123L) - val l1 = 'a.long.at(0) - val l2 = 'a.long.at(1) - val l3 = 'a.long.at(2) - val l4 = 'a.long.at(3) - val l5 = 'a.long.at(4) + val l1 = Symbol("a").long.at(0) + val l2 = Symbol("a").long.at(1) + val l3 = Symbol("a").long.at(2) + val l4 = Symbol("a").long.at(3) + val l5 = Symbol("a").long.at(4) checkEvaluation(Bin(l1), null, row) checkEvaluation(Bin(l2), java.lang.Long.toBinaryString(12), row) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index d425d0ba4218..16ca47be62f2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -464,7 +464,7 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // with dummy input, resolve the plan by the analyzer, and replace the dummy input // with a literal for tests. val unresolvedDeser = UnresolvedDeserializer(encoderFor[Map[Int, String]].deserializer) - val dummyInputPlan = LocalRelation('value.map(MapType(IntegerType, StringType))) + val dummyInputPlan = LocalRelation(Symbol("value").map(MapType(IntegerType, StringType))) val plan = Project(Alias(unresolvedDeser, "none")() :: Nil, dummyInputPlan) val analyzedPlan = SimpleAnalyzer.execute(plan) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 8d7501d952ec..0e88830b5ce2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -44,7 +44,7 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { (implicit inputToExpression: A => Expression): Unit = { checkEvaluation(mkExpr(input), expected) // check literal input - val regex = 'a.string.at(0) + val regex = Symbol("a").string.at(0) checkEvaluation(mkExpr(regex), expected, create_row(input)) // check row input } @@ -279,7 +279,7 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { evaluateWithoutCodegen("abbbbc" rlike "**") } intercept[java.util.regex.PatternSyntaxException] { - val regex = 'a.string.at(0) + val regex = Symbol("a").string.at(0) evaluateWithoutCodegen("abbbbc" rlike regex, create_row("**")) } } @@ -292,9 +292,9 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row5 = create_row("100-200", null, "###") val row6 = create_row("100-200", "(-)", null) - val s = 's.string.at(0) - val p = 'p.string.at(1) - val r = 'r.string.at(2) + val s = Symbol("s").string.at(0) + val p = Symbol("p").string.at(1) + val r = Symbol("r").string.at(2) val expr = RegExpReplace(s, p, r) checkEvaluation(expr, "num-num", row1) @@ -344,9 +344,9 @@ class 
RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row6 = create_row("100-200", null, 1) val row7 = create_row("100-200", "([a-z])", null) - val s = 's.string.at(0) - val p = 'p.string.at(1) - val r = 'r.int.at(2) + val s = Symbol("s").string.at(0) + val p = Symbol("p").string.at(1) + val r = Symbol("r").int.at(2) val expr = RegExpExtract(s, p, r) checkEvaluation(expr, "100", row1) @@ -396,9 +396,9 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row7 = create_row("100-200,300-400,500-600", null, 1) val row8 = create_row("100-200,300-400,500-600", "([a-z])", null) - val s = 's.string.at(0) - val p = 'p.string.at(1) - val r = 'r.int.at(2) + val s = Symbol("s").string.at(0) + val p = Symbol("p").string.at(1) + val r = Symbol("r").int.at(2) val expr = RegExpExtractAll(s, p, r) checkEvaluation(expr, Seq("100-200", "300-400", "500-600"), row1) @@ -437,8 +437,8 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SPLIT") { - val s1 = 'a.string.at(0) - val s2 = 'b.string.at(1) + val s1 = Symbol("a").string.at(0) + val s2 = Symbol("b").string.at(1) val row1 = create_row("aa2bb3cc", "[1-9]+") val row2 = create_row(null, "[1-9]+") val row3 = create_row("aa2bb3cc", null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index d25240fec13d..834e249d3e11 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -139,8 +139,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("StringComparison") { val row = create_row("abc", null) - val c1 = 'a.string.at(0) - val c2 = 'a.string.at(1) + val c1 = Symbol("a").string.at(0) + val c2 = Symbol("a").string.at(1) checkEvaluation(c1 contains "b", true, row) checkEvaluation(c1 contains "x", false, row) @@ -166,7 +166,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Substring") { val row = create_row("example", "example".toArray.map(_.toByte)) - val s = 'a.string.at(0) + val s = Symbol("a").string.at(0) // substring from zero position with less-than-full length checkEvaluation( @@ -240,7 +240,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Substring(s, Literal.create(-1207959552, IntegerType), Literal.create(-1207959552, IntegerType)), "", row) - val s_notNull = 'a.string.notNull.at(0) + val s_notNull = Symbol("a").string.notNull.at(0) assert(Substring(s, Literal.create(0, IntegerType), Literal.create(2, IntegerType)).nullable) assert( @@ -302,7 +302,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("ascii for string") { - val a = 'a.string.at(0) + val a = Symbol("a").string.at(0) checkEvaluation(Ascii(Literal("efg")), 101, create_row("abdef")) checkEvaluation(Ascii(a), 97, create_row("abdef")) checkEvaluation(Ascii(a), 0, create_row("")) @@ -311,7 +311,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("string for ascii") { - val a = 'a.long.at(0) + val a = Symbol("a").long.at(0) checkEvaluation(Chr(Literal(48L)), "0", create_row("abdef")) checkEvaluation(Chr(a), "a", create_row(97L)) checkEvaluation(Chr(a), "a", create_row(97L + 256L)) @@ 
-324,8 +324,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("base64/unbase64 for string") { - val a = 'a.string.at(0) - val b = 'b.binary.at(0) + val a = Symbol("a").string.at(0) + val b = Symbol("b").binary.at(0) val bytes = Array[Byte](1, 2, 3, 4) checkEvaluation(Base64(Literal(bytes)), "AQIDBA==", create_row("abdef")) @@ -344,8 +344,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("encode/decode for string") { - val a = 'a.string.at(0) - val b = 'b.binary.at(0) + val a = Symbol("a").string.at(0) + val b = Symbol("b").binary.at(0) // scalastyle:off // non ascii characters are not allowed in the code, so we disable the scalastyle here. checkEvaluation( @@ -561,7 +561,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("TRIM") { - val s = 'a.string.at(0) + val s = Symbol("a").string.at(0) checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef ")) checkEvaluation(StringTrim("aa", "a"), "", create_row(" abdef ")) checkEvaluation(StringTrim(Literal(" aabbtrimccc"), "ab cd"), "trim", create_row("bdef")) @@ -592,7 +592,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("LTRIM") { - val s = 'a.string.at(0) + val s = Symbol("a").string.at(0) checkEvaluation(StringTrimLeft(Literal(" aa ")), "aa ", create_row(" abdef ")) checkEvaluation(StringTrimLeft(Literal("aa"), "a"), "", create_row(" abdef ")) checkEvaluation(StringTrimLeft(Literal("aa "), "a "), "", create_row(" abdef ")) @@ -624,7 +624,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("RTRIM") { - val s = 'a.string.at(0) + val s = Symbol("a").string.at(0) checkEvaluation(StringTrimRight(Literal(" aa ")), " aa", create_row(" abdef ")) checkEvaluation(StringTrimRight(Literal("a"), "a"), "", create_row(" abdef ")) checkEvaluation(StringTrimRight(Literal("ab"), "ab"), "", create_row(" abdef ")) @@ -681,9 +681,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("INSTR") { - val s1 = 'a.string.at(0) - val s2 = 'b.string.at(1) - val s3 = 'c.string.at(2) + val s1 = Symbol("a").string.at(0) + val s2 = Symbol("b").string.at(1) + val s3 = Symbol("c").string.at(2) val row1 = create_row("aaads", "aa", "zz") checkEvaluation(StringInstr(Literal("aaads"), Literal("aa")), 1, row1) @@ -706,10 +706,10 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("LOCATE") { - val s1 = 'a.string.at(0) - val s2 = 'b.string.at(1) - val s3 = 'c.string.at(2) - val s4 = 'd.int.at(3) + val s1 = Symbol("a").string.at(0) + val s2 = Symbol("b").string.at(1) + val s3 = Symbol("c").string.at(2) + val s4 = Symbol("d").int.at(3) val row1 = create_row("aaads", "aa", "zz", 2) val row2 = create_row(null, "aa", "zz", 1) val row3 = create_row("aaads", null, "zz", 1) @@ -733,9 +733,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("LPAD/RPAD") { - val s1 = 'a.string.at(0) - val s2 = 'b.int.at(1) - val s3 = 'c.string.at(2) + val s1 = Symbol("a").string.at(0) + val s2 = Symbol("b").int.at(1) + val s3 = Symbol("c").string.at(2) val row1 = create_row("hi", 5, "??") val row2 = create_row("hi", 1, "?") val row3 = create_row(null, 1, "?") @@ -768,8 +768,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("REPEAT") { - val s1 = 'a.string.at(0) - val s2 = 'b.int.at(1) + val s1 = Symbol("a").string.at(0) + val s2 = 
Symbol("b").int.at(1) val row1 = create_row("hi", 2) val row2 = create_row(null, 1) @@ -783,7 +783,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("REVERSE") { - val s = 'a.string.at(0) + val s = Symbol("a").string.at(0) val row1 = create_row("abccc") checkEvaluation(Reverse(Literal("abccc")), "cccba", row1) checkEvaluation(Reverse(s), "cccba", row1) @@ -791,7 +791,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("SPACE") { - val s1 = 'b.int.at(0) + val s1 = Symbol("b").int.at(0) val row1 = create_row(2) val row2 = create_row(null) @@ -803,8 +803,8 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("length for string / binary") { - val a = 'a.string.at(0) - val b = 'b.binary.at(0) + val a = Symbol("a").string.at(0) + val b = Symbol("b").binary.at(0) val bytes = Array[Byte](1, 2, 3, 1, 2) val string = "abdef" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala index 53e8ee9fbe71..a4eb95218bec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala @@ -282,7 +282,7 @@ class ApproximatePercentileSuite extends SparkFunSuite { } test("class ApproximatePercentile, automatically add type casting for parameters") { - val testRelation = LocalRelation('a.int) + val testRelation = LocalRelation(Symbol("a").int) // accuracy types must be integral, no type casting val accuracyExpressions = Seq( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index 8984bad479a6..dd5ad8af2313 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -35,12 +35,13 @@ class AggregateOptimizeSuite extends AnalysisTest { RemoveRepetitionFromGroupExpressions) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) test("remove literals in grouping expression") { - val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b)) + val query = + testRelation.groupBy(Symbol("a"), Literal("1"), Literal(1) + Literal(2))(sum(Symbol("b"))) val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze + val correctAnswer = testRelation.groupBy(Symbol("a"))(sum(Symbol("b"))).analyze comparePlans(optimized, correctAnswer) } @@ -48,26 +49,37 @@ class AggregateOptimizeSuite extends AnalysisTest { test("do not remove all grouping expressions if they are all literals") { withSQLConf(CASE_SENSITIVE.key -> "false", GROUP_BY_ORDINAL.key -> "false") { val analyzer = getAnalyzer - val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b)) + val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum(Symbol("b"))) val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = 
analyzer.execute(testRelation.groupBy(Literal(0))(sum('b))) + val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum(Symbol("b")))) comparePlans(optimized, correctAnswer) } } test("Remove aliased literals") { - val query = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a, 'y)(sum('b)) + val query = testRelation.select( + Symbol("a"), + Symbol("b"), + Literal(1).as(Symbol("y"))).groupBy(Symbol("a"), Symbol("y"))(sum(Symbol("b"))) val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a)(sum('b)).analyze + val correctAnswer = testRelation.select( + Symbol("a"), + Symbol("b"), + Literal(1).as(Symbol("y"))).groupBy(Symbol("a"))(sum(Symbol("b"))).analyze comparePlans(optimized, correctAnswer) } test("remove repetition in grouping expression") { - val query = testRelation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c)) + val query = testRelation.groupBy( + Symbol("a") + 1, + Symbol("b") + 2, + Literal(1) + Symbol("A"), + Literal(2) + Symbol("B"))(sum(Symbol("c"))) val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = testRelation.groupBy('a + 1, 'b + 2)(sum('c)).analyze + val correctAnswer = + testRelation.groupBy(Symbol("a") + 1, Symbol("b") + 2)(sum(Symbol("c"))).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala index c02691848c8f..7dfc2bf7f9bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -44,8 +44,8 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper PruneFilters) :: Nil } - val nullableRelation = LocalRelation('a.int.withNullability(true)) - val nonNullableRelation = LocalRelation('a.int.withNullability(false)) + val nullableRelation = LocalRelation(Symbol("a").int.withNullability(true)) + val nonNullableRelation = LocalRelation(Symbol("a").int.withNullability(false)) test("Preserve nullable exprs when constraintPropagation is false") { withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { @@ -68,15 +68,16 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper } test("Nullable Simplification Primitive: <=>") { - val plan = nullableRelation.select('a <=> 'a).analyze + val plan = nullableRelation.select(Symbol("a") <=> Symbol("a")).analyze val actual = Optimize.execute(plan) val correctAnswer = nullableRelation.select(Alias(TrueLiteral, "(a <=> a)")()).analyze comparePlans(actual, correctAnswer) } test("Non-Nullable Simplification Primitive") { - val plan = nonNullableRelation - .select('a === 'a, 'a <=> 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a).analyze + val plan = nonNullableRelation.select( + Symbol("a") === Symbol("a"), Symbol("a") <=> Symbol("a"), Symbol("a") <= Symbol("a"), + Symbol("a") >= Symbol("a"), Symbol("a") < Symbol("a"), Symbol("a") > Symbol("a")).analyze val actual = Optimize.execute(plan) val correctAnswer = nonNullableRelation .select( @@ -92,9 +93,9 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper test("Expression Normalization") { val plan = nonNullableRelation.where( - 'a * Literal(100) + Pi() 
=== Pi() + Literal(100) * 'a && - DateAdd(CurrentDate(), 'a + Literal(2)) <= DateAdd(CurrentDate(), Literal(2) + 'a)) - .analyze + Symbol("a") * Literal(100) + Pi() === Pi() + Literal(100) * Symbol("a") && + DateAdd(CurrentDate(), Symbol("a") + Literal(2)) <= + DateAdd(CurrentDate(), Literal(2) + Symbol("a"))).analyze val actual = Optimize.execute(plan) val correctAnswer = nonNullableRelation.analyze comparePlans(actual, correctAnswer) @@ -140,7 +141,7 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper Seq(a === a, a <= a, a >= a).foreach { condition => val plan = nullableRelation.where(condition).analyze val actual = Optimize.execute(plan) - val correctAnswer = nullableRelation.where('a.isNotNull).analyze + val correctAnswer = nullableRelation.where(Symbol("a").isNotNull).analyze comparePlans(actual, correctAnswer) } @@ -160,7 +161,7 @@ class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper And(a >= a, a.isNotNull)).foreach { condition => val plan = nullableRelation.where(condition).analyze val actual = Optimize.execute(plan) - val correctAnswer = nullableRelation.where('a.isNotNull).analyze + val correctAnswer = nullableRelation.where(Symbol("a").isNotNull).analyze comparePlans(actual, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 04dcf50e0c3c..879db45278b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -42,16 +42,18 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper with PruneFilters) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string, - 'e.boolean, 'f.boolean, 'g.boolean, 'h.boolean) + val testRelation = LocalRelation( + Symbol("a").int, Symbol("b").int, Symbol("c").int, Symbol("d").string, + Symbol("e").boolean, Symbol("f").boolean, Symbol("g").boolean, Symbol("h").boolean) val testRelationWithData = LocalRelation.fromExternalRows( testRelation.output, Seq(Row(1, 2, 3, "abc")) ) - val testNotNullableRelation = LocalRelation('a.int.notNull, 'b.int.notNull, 'c.int.notNull, - 'd.string.notNull, 'e.boolean.notNull, 'f.boolean.notNull, 'g.boolean.notNull, - 'h.boolean.notNull) + val testNotNullableRelation = LocalRelation( + Symbol("a").int.notNull, Symbol("b").int.notNull, Symbol("c").int.notNull, + Symbol("d").string.notNull, Symbol("e").boolean.notNull, Symbol("f").boolean.notNull, + Symbol("g").boolean.notNull, Symbol("h").boolean.notNull) val testNotNullableRelationWithData = LocalRelation.fromExternalRows( testNotNullableRelation.output, Seq(Row(1, 2, 3, "abc")) @@ -86,105 +88,138 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper with } test("a && a => a") { - checkCondition(Literal(1) < 'a && Literal(1) < 'a, Literal(1) < 'a) - checkCondition(Literal(1) < 'a && Literal(1) < 'a && Literal(1) < 'a, Literal(1) < 'a) + checkCondition(Literal(1) < Symbol("a") && Literal(1) < Symbol("a"), Literal(1) < Symbol("a")) + checkCondition(Literal(1) < Symbol("a") && Literal(1) < Symbol("a") && + Literal(1) < Symbol("a"), Literal(1) < Symbol("a")) } test("a || a => a") { - checkCondition(Literal(1) < 'a || Literal(1) < 'a, Literal(1) < 'a) - checkCondition(Literal(1) < 'a || Literal(1) < 'a || 
Literal(1) < 'a, Literal(1) < 'a) + checkCondition(Literal(1) < Symbol("a") || Literal(1) < Symbol("a"), Literal(1) < Symbol("a")) + checkCondition(Literal(1) < Symbol("a") || Literal(1) < Symbol("a") || + Literal(1) < Symbol("a"), Literal(1) < Symbol("a")) } test("(a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ...") { - checkCondition('b > 3 || 'c > 5, 'b > 3 || 'c > 5) + checkCondition(Symbol("b") > 3 || Symbol("c") > 5, Symbol("b") > 3 || Symbol("c") > 5) - checkCondition(('a < 2 && 'a > 3 && 'b > 5) || 'a < 2, 'a < 2) + checkCondition((Symbol("a") < 2 && Symbol("a") > 3 && + Symbol("b") > 5) || Symbol("a") < 2, Symbol("a") < 2) - checkCondition('a < 2 || ('a < 2 && 'a > 3 && 'b > 5), 'a < 2) + checkCondition(Symbol("a") < 2 || (Symbol("a") < 2 && + Symbol("a") > 3 && Symbol("b") > 5), Symbol("a") < 2) - val input = ('a === 'b && 'b > 3 && 'c > 2) || - ('a === 'b && 'c < 1 && 'a === 5) || - ('a === 'b && 'b < 5 && 'a > 1) + val input = (Symbol("a") === Symbol("b") && Symbol("b") > 3 && Symbol("c") > 2) || + (Symbol("a") === Symbol("b") && Symbol("c") < 1 && Symbol("a") === 5) || + (Symbol("a") === Symbol("b") && Symbol("b") < 5 && Symbol("a") > 1) - val expected = 'a === 'b && ( - ('b > 3 && 'c > 2) || ('c < 1 && 'a === 5) || ('b < 5 && 'a > 1)) + val expected = Symbol("a") === Symbol("b") && + ((Symbol("b") > 3 && Symbol("c") > 2) || (Symbol("c") < 1 && + Symbol("a") === 5) || (Symbol("b") < 5 && Symbol("a") > 1)) checkCondition(input, expected) } test("(a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ...") { - checkCondition('b > 3 && 'c > 5, 'b > 3 && 'c > 5) + checkCondition(Symbol("b") > 3 && Symbol("c") > 5, Symbol("b") > 3 && Symbol("c") > 5) - checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2) + checkCondition((Symbol("a") < 2 || Symbol("a") > 3 || + Symbol("b") > 5) && Symbol("a") < 2, Symbol("a") < 2) - checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5), 'a < 2) + checkCondition(Symbol("a") < 2 && + (Symbol("a") < 2 || Symbol("a") > 3 || Symbol("b") > 5), Symbol("a") < 2) - checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5)) + checkCondition((Symbol("a") < 2 || Symbol("b") > 3) && + (Symbol("a") < 2 || Symbol("c") > 5), Symbol("a") < 2 || (Symbol("b") > 3 && Symbol("c") > 5)) checkCondition( - ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5), - 'a === 'b || 'b > 3 && 'a > 3 && 'a < 5) + (Symbol("a") === Symbol("b") || Symbol("b") > 3) && + (Symbol("a") === Symbol("b") || Symbol("a") > 3) && + (Symbol("a") === Symbol("b") || Symbol("a") < 5), + Symbol("a") === Symbol("b") || Symbol("b") > 3 && Symbol("a") > 3 && Symbol("a") < 5) } test("e && (!e || f) - not nullable") { - checkConditionInNotNullableRelation('e && (!'e || 'f ), 'e && 'f) + checkConditionInNotNullableRelation(Symbol("e") && + (!Symbol("e") || Symbol("f") ), Symbol("e") && Symbol("f")) - checkConditionInNotNullableRelation('e && ('f || !'e ), 'e && 'f) + checkConditionInNotNullableRelation(Symbol("e") && + (Symbol("f") || !Symbol("e") ), Symbol("e") && Symbol("f")) - checkConditionInNotNullableRelation((!'e || 'f ) && 'e, 'f && 'e) + checkConditionInNotNullableRelation((!Symbol("e") || Symbol("f") ) && + Symbol("e"), Symbol("f") && Symbol("e")) - checkConditionInNotNullableRelation(('f || !'e ) && 'e, 'f && 'e) + checkConditionInNotNullableRelation((Symbol("f") || !Symbol("e") ) && + Symbol("e"), Symbol("f") && Symbol("e")) } test("e && (!e || f) - nullable") { - Seq ('e && (!'e || 'f ), - 'e && ('f 
|| !'e ), - (!'e || 'f ) && 'e, - ('f || !'e ) && 'e, - 'e || (!'e && 'f), - 'e || ('f && !'e), - ('e && 'f) || !'e, - ('f && 'e) || !'e).foreach { expr => + Seq (Symbol("e") && (!Symbol("e") || Symbol("f") ), + Symbol("e") && (Symbol("f") || !Symbol("e") ), + (!Symbol("e") || Symbol("f") ) && Symbol("e"), + (Symbol("f") || !Symbol("e") ) && Symbol("e"), + Symbol("e") || (!Symbol("e") && Symbol("f")), + Symbol("e") || (Symbol("f") && !Symbol("e")), + (Symbol("e") && Symbol("f")) || !Symbol("e"), + (Symbol("f") && Symbol("e")) || !Symbol("e")).foreach { expr => checkCondition(expr, expr) } } test("a < 1 && (!(a < 1) || f) - not nullable") { - checkConditionInNotNullableRelation('a < 1 && (!('a < 1) || 'f), ('a < 1) && 'f) - checkConditionInNotNullableRelation('a < 1 && ('f || !('a < 1)), ('a < 1) && 'f) - - checkConditionInNotNullableRelation('a <= 1 && (!('a <= 1) || 'f), ('a <= 1) && 'f) - checkConditionInNotNullableRelation('a <= 1 && ('f || !('a <= 1)), ('a <= 1) && 'f) - - checkConditionInNotNullableRelation('a > 1 && (!('a > 1) || 'f), ('a > 1) && 'f) - checkConditionInNotNullableRelation('a > 1 && ('f || !('a > 1)), ('a > 1) && 'f) - - checkConditionInNotNullableRelation('a >= 1 && (!('a >= 1) || 'f), ('a >= 1) && 'f) - checkConditionInNotNullableRelation('a >= 1 && ('f || !('a >= 1)), ('a >= 1) && 'f) + checkConditionInNotNullableRelation(Symbol("a") < 1 && + (!(Symbol("a") < 1) || Symbol("f")), (Symbol("a") < 1) && Symbol("f")) + checkConditionInNotNullableRelation(Symbol("a") < 1 && + (Symbol("f") || !(Symbol("a") < 1)), (Symbol("a") < 1) && Symbol("f")) + + checkConditionInNotNullableRelation(Symbol("a") <= 1 && + (!(Symbol("a") <= 1) || Symbol("f")), (Symbol("a") <= 1) && Symbol("f")) + checkConditionInNotNullableRelation(Symbol("a") <= 1 && + (Symbol("f") || !(Symbol("a") <= 1)), (Symbol("a") <= 1) && Symbol("f")) + + checkConditionInNotNullableRelation(Symbol("a") > 1 && + (!(Symbol("a") > 1) || Symbol("f")), (Symbol("a") > 1) && Symbol("f")) + checkConditionInNotNullableRelation(Symbol("a") > 1 && + (Symbol("f") || !(Symbol("a") > 1)), (Symbol("a") > 1) && Symbol("f")) + + checkConditionInNotNullableRelation(Symbol("a") >= 1 && + (!(Symbol("a") >= 1) || Symbol("f")), (Symbol("a") >= 1) && Symbol("f")) + checkConditionInNotNullableRelation(Symbol("a") >= 1 && + (Symbol("f") || !(Symbol("a") >= 1)), (Symbol("a") >= 1) && Symbol("f")) } test("a < 1 && ((a >= 1) || f) - not nullable") { - checkConditionInNotNullableRelation('a < 1 && ('a >= 1 || 'f ), ('a < 1) && 'f) - checkConditionInNotNullableRelation('a < 1 && ('f || 'a >= 1), ('a < 1) && 'f) - - checkConditionInNotNullableRelation('a <= 1 && ('a > 1 || 'f ), ('a <= 1) && 'f) - checkConditionInNotNullableRelation('a <= 1 && ('f || 'a > 1), ('a <= 1) && 'f) - - checkConditionInNotNullableRelation('a > 1 && (('a <= 1) || 'f), ('a > 1) && 'f) - checkConditionInNotNullableRelation('a > 1 && ('f || ('a <= 1)), ('a > 1) && 'f) - - checkConditionInNotNullableRelation('a >= 1 && (('a < 1) || 'f), ('a >= 1) && 'f) - checkConditionInNotNullableRelation('a >= 1 && ('f || ('a < 1)), ('a >= 1) && 'f) + checkConditionInNotNullableRelation(Symbol("a") < 1 && + (Symbol("a") >= 1 || Symbol("f") ), (Symbol("a") < 1) && Symbol("f")) + checkConditionInNotNullableRelation(Symbol("a") < 1 && + (Symbol("f") || Symbol("a") >= 1), (Symbol("a") < 1) && Symbol("f")) + + checkConditionInNotNullableRelation(Symbol("a") <= 1 && + (Symbol("a") > 1 || Symbol("f") ), (Symbol("a") <= 1) && Symbol("f")) + checkConditionInNotNullableRelation(Symbol("a") <= 1 && 
+ (Symbol("f") || Symbol("a") > 1), (Symbol("a") <= 1) && Symbol("f")) + + checkConditionInNotNullableRelation(Symbol("a") > 1 && + ((Symbol("a") <= 1) || Symbol("f")), (Symbol("a") > 1) && Symbol("f")) + checkConditionInNotNullableRelation(Symbol("a") > 1 && + (Symbol("f") || (Symbol("a") <= 1)), (Symbol("a") > 1) && Symbol("f")) + + checkConditionInNotNullableRelation(Symbol("a") >= 1 && + ((Symbol("a") < 1) || Symbol("f")), (Symbol("a") >= 1) && Symbol("f")) + checkConditionInNotNullableRelation(Symbol("a") >= 1 && + (Symbol("f") || (Symbol("a") < 1)), (Symbol("a") >= 1) && Symbol("f")) } test("DeMorgan's law") { - checkCondition(!('e && 'f), !'e || !'f) + checkCondition(!(Symbol("e") && Symbol("f")), !Symbol("e") || !Symbol("f")) - checkCondition(!('e || 'f), !'e && !'f) + checkCondition(!(Symbol("e") || Symbol("f")), !Symbol("e") && !Symbol("f")) - checkCondition(!(('e && 'f) || ('g && 'h)), (!'e || !'f) && (!'g || !'h)) + checkCondition(!((Symbol("e") && Symbol("f")) || (Symbol("g") && + Symbol("h"))), (!Symbol("e") || !Symbol("f")) && (!Symbol("g") || !Symbol("h"))) - checkCondition(!(('e || 'f) && ('g || 'h)), (!'e && !'f) || (!'g && !'h)) + checkCondition(!((Symbol("e") || Symbol("f")) && + (Symbol("g") || Symbol("h"))), (!Symbol("e") && + !Symbol("f")) || (!Symbol("g") && !Symbol("h"))) } private val analyzer = new Analyzer( @@ -192,53 +227,62 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper with test("(a && b) || (a && c) => a && (b || c) when case insensitive") { val plan = analyzer.execute( - testRelation.where(('a > 2 && 'b > 3) || ('A > 2 && 'b < 5))) + testRelation.where((Symbol("a") > 2 && + Symbol("b") > 3) || (Symbol("a") > 2 && Symbol("b") < 5))) val actual = Optimize.execute(plan) val expected = analyzer.execute( - testRelation.where('a > 2 && ('b > 3 || 'b < 5))) + testRelation.where(Symbol("a") > 2 && + (Symbol("b") > 3 || Symbol("b") < 5))) comparePlans(actual, expected) } test("(a || b) && (a || c) => a || (b && c) when case insensitive") { val plan = analyzer.execute( - testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5))) + testRelation.where((Symbol("a") > 2 || Symbol("b") > 3) && + (Symbol("a") > 2 || Symbol("b") < 5))) val actual = Optimize.execute(plan) val expected = analyzer.execute( - testRelation.where('a > 2 || ('b > 3 && 'b < 5))) + testRelation.where(Symbol("a") > 2 || (Symbol("b") > 3 && Symbol("b") < 5))) comparePlans(actual, expected) } test("Complementation Laws") { - checkConditionInNotNullableRelation('e && !'e, testNotNullableRelation) - checkConditionInNotNullableRelation(!'e && 'e, testNotNullableRelation) + checkConditionInNotNullableRelation(Symbol("e") && !Symbol("e"), testNotNullableRelation) + checkConditionInNotNullableRelation(!Symbol("e") && Symbol("e"), testNotNullableRelation) - checkConditionInNotNullableRelation('e || !'e, testNotNullableRelationWithData) - checkConditionInNotNullableRelation(!'e || 'e, testNotNullableRelationWithData) + checkConditionInNotNullableRelation( + Symbol("e") || !Symbol("e"), testNotNullableRelationWithData) + checkConditionInNotNullableRelation( + !Symbol("e") || Symbol("e"), testNotNullableRelationWithData) } test("Complementation Laws - null handling") { - checkCondition('e && !'e, - testRelationWithData.where(And(Literal(null, BooleanType), 'e.isNull)).analyze) - checkCondition(!'e && 'e, - testRelationWithData.where(And(Literal(null, BooleanType), 'e.isNull)).analyze) - - checkCondition('e || !'e, - testRelationWithData.where(Or('e.isNotNull, Literal(null, 
BooleanType))).analyze) - checkCondition(!'e || 'e, - testRelationWithData.where(Or('e.isNotNull, Literal(null, BooleanType))).analyze) + checkCondition(Symbol("e") && !Symbol("e"), + testRelationWithData.where(And(Literal(null, BooleanType), Symbol("e").isNull)).analyze) + checkCondition(!Symbol("e") && Symbol("e"), + testRelationWithData.where(And(Literal(null, BooleanType), Symbol("e").isNull)).analyze) + + checkCondition(Symbol("e") || !Symbol("e"), + testRelationWithData.where(Or(Symbol("e").isNotNull, Literal(null, BooleanType))).analyze) + checkCondition(!Symbol("e") || Symbol("e"), + testRelationWithData.where(Or(Symbol("e").isNotNull, Literal(null, BooleanType))).analyze) } test("Complementation Laws - negative case") { - checkCondition('e && !'f, testRelationWithData.where('e && !'f).analyze) - checkCondition(!'f && 'e, testRelationWithData.where(!'f && 'e).analyze) - - checkCondition('e || !'f, testRelationWithData.where('e || !'f).analyze) - checkCondition(!'f || 'e, testRelationWithData.where(!'f || 'e).analyze) + checkCondition(Symbol("e") && !Symbol("f"), + testRelationWithData.where(Symbol("e") && !Symbol("f")).analyze) + checkCondition(!Symbol("f") && Symbol("e"), + testRelationWithData.where(!Symbol("f") && Symbol("e")).analyze) + + checkCondition(Symbol("e") || !Symbol("f"), + testRelationWithData.where(Symbol("e") || !Symbol("f")).analyze) + checkCondition(!Symbol("f") || Symbol("e"), + testRelationWithData.where(!Symbol("f") || Symbol("e")).analyze) } test("simplify NOT(IsNull(x)) and NOT(IsNotNull(x))") { - checkCondition(Not(IsNotNull('b)), IsNull('b)) - checkCondition(Not(IsNull('b)), IsNotNull('b)) + checkCondition(Not(IsNotNull(Symbol("b"))), IsNull(Symbol("b"))) + checkCondition(Not(IsNull(Symbol("b"))), IsNotNull(Symbol("b"))) } protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { @@ -249,8 +293,8 @@ class BooleanSimplificationSuite extends PlanTest with ExpressionEvalHelper with test("filter reduction - positive cases") { val fields = Seq( - 'col1NotNULL.boolean.notNull, - 'col2NotNULL.boolean.notNull + Symbol("col1NotNULL").boolean.notNull, + Symbol("col2NotNULL").boolean.notNull ) val Seq(col1NotNULL, col2NotNULL) = fields.zipWithIndex.map { case (f, i) => f.at(i) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala index dea2b36ecc84..ec91376b08e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CheckCartesianProductsSuite.scala @@ -34,11 +34,11 @@ class CheckCartesianProductsSuite extends PlanTest { val batches = Batch("Check Cartesian Products", Once, CheckCartesianProducts) :: Nil } - val testRelation1 = LocalRelation('a.int, 'b.int) - val testRelation2 = LocalRelation('c.int, 'd.int) + val testRelation1 = LocalRelation(Symbol("a").int, Symbol("b").int) + val testRelation2 = LocalRelation(Symbol("c").int, Symbol("d").int) val joinTypesWithRequiredCondition = Seq(Inner, LeftOuter, RightOuter, FullOuter) - val joinTypesWithoutRequiredCondition = Seq(LeftSemi, LeftAnti, ExistenceJoin('exists)) + val joinTypesWithoutRequiredCondition = Seq(LeftSemi, LeftAnti, ExistenceJoin(Symbol("exists"))) test("CheckCartesianProducts doesn't throw an exception if cross joins are enabled)") { withSQLConf(CROSS_JOINS_ENABLED.key -> "true") { @@ -65,7 
+65,7 @@ class CheckCartesianProductsSuite extends PlanTest { withSQLConf(CROSS_JOINS_ENABLED.key -> "false") { for (joinType <- joinTypesWithRequiredCondition) { noException should be thrownBy { - performCartesianProductCheck(joinType, Some('a === 'd)) + performCartesianProductCheck(joinType, Some(Symbol("a") === Symbol("d"))) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala index 42bcd13ee378..f7a07baac781 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala @@ -33,37 +33,38 @@ class CollapseProjectSuite extends PlanTest { Batch("CollapseProject", Once, CollapseProject) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int) test("collapse two deterministic, independent projects into one") { val query = testRelation - .select(('a + 1).as('a_plus_1), 'b) - .select('a_plus_1, ('b + 1).as('b_plus_1)) + .select((Symbol("a") + 1).as(Symbol("a_plus_1")), Symbol("b")) + .select(Symbol("a_plus_1"), (Symbol("b") + 1).as(Symbol("b_plus_1"))) val optimized = Optimize.execute(query.analyze) - val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze + val correctAnswer = testRelation.select((Symbol("a") + 1).as(Symbol("a_plus_1")), + (Symbol("b") + 1).as(Symbol("b_plus_1"))).analyze comparePlans(optimized, correctAnswer) } test("collapse two deterministic, dependent projects into one") { val query = testRelation - .select(('a + 1).as('a_plus_1), 'b) - .select(('a_plus_1 + 1).as('a_plus_2), 'b) + .select((Symbol("a") + 1).as(Symbol("a_plus_1")), Symbol("b")) + .select((Symbol("a_plus_1") + 1).as(Symbol("a_plus_2")), Symbol("b")) val optimized = Optimize.execute(query.analyze) val correctAnswer = testRelation.select( - (('a + 1).as('a_plus_1) + 1).as('a_plus_2), - 'b).analyze + ((Symbol("a") + 1).as(Symbol("a_plus_1")) + 1).as(Symbol("a_plus_2")), + Symbol("b")).analyze comparePlans(optimized, correctAnswer) } test("do not collapse nondeterministic projects") { val query = testRelation - .select(Rand(10).as('rand)) - .select(('rand + 1).as('rand1), ('rand + 2).as('rand2)) + .select(Rand(10).as(Symbol("rand"))) + .select((Symbol("rand") + 1).as(Symbol("rand1")), (Symbol("rand") + 2).as(Symbol("rand2"))) val optimized = Optimize.execute(query.analyze) val correctAnswer = query.analyze @@ -73,47 +74,48 @@ class CollapseProjectSuite extends PlanTest { test("collapse two nondeterministic, independent projects into one") { val query = testRelation - .select(Rand(10).as('rand)) - .select(Rand(20).as('rand2)) + .select(Rand(10).as(Symbol("rand"))) + .select(Rand(20).as(Symbol("rand2"))) val optimized = Optimize.execute(query.analyze) val correctAnswer = testRelation - .select(Rand(20).as('rand2)).analyze + .select(Rand(20).as(Symbol("rand2"))).analyze comparePlans(optimized, correctAnswer) } test("collapse one nondeterministic, one deterministic, independent projects into one") { val query = testRelation - .select(Rand(10).as('rand), 'a) - .select(('a + 1).as('a_plus_1)) + .select(Rand(10).as(Symbol("rand")), Symbol("a")) + .select((Symbol("a") + 1).as(Symbol("a_plus_1"))) val optimized = Optimize.execute(query.analyze) val correctAnswer = testRelation - .select(('a + 
1).as('a_plus_1)).analyze + .select((Symbol("a") + 1).as(Symbol("a_plus_1"))).analyze comparePlans(optimized, correctAnswer) } test("collapse project into aggregate") { val query = testRelation - .groupBy('a, 'b)(('a + 1).as('a_plus_1), 'b) - .select('a_plus_1, ('b + 1).as('b_plus_1)) + .groupBy(Symbol("a"), Symbol("b"))((Symbol("a") + 1).as(Symbol("a_plus_1")), Symbol("b")) + .select(Symbol("a_plus_1"), (Symbol("b") + 1).as(Symbol("b_plus_1"))) val optimized = Optimize.execute(query.analyze) val correctAnswer = testRelation - .groupBy('a, 'b)(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze + .groupBy(Symbol("a"), Symbol("b"))((Symbol("a") + 1).as(Symbol("a_plus_1")), + (Symbol("b") + 1).as(Symbol("b_plus_1"))).analyze comparePlans(optimized, correctAnswer) } test("do not collapse common nondeterministic project and aggregate") { val query = testRelation - .groupBy('a)('a, Rand(10).as('rand)) - .select(('rand + 1).as('rand1), ('rand + 2).as('rand2)) + .groupBy(Symbol("a"))(Symbol("a"), Rand(10).as(Symbol("rand"))) + .select((Symbol("rand") + 1).as(Symbol("rand1")), (Symbol("rand") + 2).as(Symbol("rand2"))) val optimized = Optimize.execute(query.analyze) val correctAnswer = query.analyze @@ -128,8 +130,8 @@ class CollapseProjectSuite extends PlanTest { val metadata = new MetadataBuilder().putLong("key", 1).build() val analyzed = - Project(Seq(Alias('a_with_metadata, "b")()), - Project(Seq(Alias('a, "a_with_metadata")(explicitMetadata = Some(metadata))), + Project(Seq(Alias(Symbol("a_with_metadata"), "b")()), + Project(Seq(Alias(Symbol("a"), "a_with_metadata")(explicitMetadata = Some(metadata))), testRelation.logicalPlan)).analyze require(hasMetadata(analyzed)) @@ -140,34 +142,38 @@ class CollapseProjectSuite extends PlanTest { } test("collapse redundant alias through limit") { - val relation = LocalRelation('a.int, 'b.int) - val query = relation.select('a as 'b).limit(1).select('b as 'c).analyze + val relation = LocalRelation(Symbol("a").int, Symbol("b").int) + val query = relation.select( + Symbol("a") as Symbol("b")).limit(1).select(Symbol("b") as Symbol("c")).analyze val optimized = Optimize.execute(query) - val expected = relation.select('a as 'c).limit(1).analyze + val expected = relation.select(Symbol("a") as Symbol("c")).limit(1).analyze comparePlans(optimized, expected) } test("collapse redundant alias through local limit") { - val relation = LocalRelation('a.int, 'b.int) - val query = LocalLimit(1, relation.select('a as 'b)).select('b as 'c).analyze + val relation = LocalRelation(Symbol("a").int, Symbol("b").int) + val query = LocalLimit( + 1, relation.select(Symbol("a") as Symbol("b"))).select(Symbol("b") as Symbol("c")).analyze val optimized = Optimize.execute(query) - val expected = LocalLimit(1, relation.select('a as 'c)).analyze + val expected = LocalLimit(1, relation.select(Symbol("a") as Symbol("c"))).analyze comparePlans(optimized, expected) } test("collapse redundant alias through repartition") { - val relation = LocalRelation('a.int, 'b.int) - val query = relation.select('a as 'b).repartition(1).select('b as 'c).analyze + val relation = LocalRelation(Symbol("a").int, Symbol("b").int) + val query = relation.select( + Symbol("a") as Symbol("b")).repartition(1).select(Symbol("b") as Symbol("c")).analyze val optimized = Optimize.execute(query) - val expected = relation.select('a as 'c).repartition(1).analyze + val expected = relation.select(Symbol("a") as Symbol("c")).repartition(1).analyze comparePlans(optimized, expected) } test("collapse redundant alias through 
sample") { - val relation = LocalRelation('a.int, 'b.int) - val query = Sample(0.0, 0.6, false, 11L, relation.select('a as 'b)).select('b as 'c).analyze + val relation = LocalRelation(Symbol("a").int, Symbol("b").int) + val query = Sample(0.0, 0.6, false, 11L, relation.select( + Symbol("a") as Symbol("b"))).select(Symbol("b") as Symbol("c")).analyze val optimized = Optimize.execute(query) - val expected = Sample(0.0, 0.6, false, 11L, relation.select('a as 'c)).analyze + val expected = Sample(0.0, 0.6, false, 11L, relation.select(Symbol("a") as Symbol("c"))).analyze comparePlans(optimized, expected) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala index 8cc8decd65de..6a2a3778395e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala @@ -30,7 +30,7 @@ class CollapseRepartitionSuite extends PlanTest { CollapseRepartition) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int) test("collapse two adjacent coalesces into one") { @@ -110,14 +110,14 @@ class CollapseRepartitionSuite extends PlanTest { // Always respects the top distribute and removes useless repartition val query1 = testRelation .repartition(10) - .distribute('a)(20) + .distribute(Symbol("a"))(20) val query2 = testRelation .repartition(30) - .distribute('a)(20) + .distribute(Symbol("a"))(20) val optimized1 = Optimize.execute(query1.analyze) val optimized2 = Optimize.execute(query2.analyze) - val correctAnswer = testRelation.distribute('a)(20).analyze + val correctAnswer = testRelation.distribute(Symbol("a"))(20).analyze comparePlans(optimized1, correctAnswer) comparePlans(optimized2, correctAnswer) @@ -127,14 +127,14 @@ class CollapseRepartitionSuite extends PlanTest { // Always respects the top distribute and removes useless coalesce below repartition val query1 = testRelation .coalesce(10) - .distribute('a)(20) + .distribute(Symbol("a"))(20) val query2 = testRelation .coalesce(30) - .distribute('a)(20) + .distribute(Symbol("a"))(20) val optimized1 = Optimize.execute(query1.analyze) val optimized2 = Optimize.execute(query2.analyze) - val correctAnswer = testRelation.distribute('a)(20).analyze + val correctAnswer = testRelation.distribute(Symbol("a"))(20).analyze comparePlans(optimized1, correctAnswer) comparePlans(optimized2, correctAnswer) @@ -143,10 +143,10 @@ class CollapseRepartitionSuite extends PlanTest { test("repartition above distribute") { // Always respects the top repartition and removes useless distribute below repartition val query1 = testRelation - .distribute('a)(10) + .distribute(Symbol("a"))(10) .repartition(20) val query2 = testRelation - .distribute('a)(30) + .distribute(Symbol("a"))(30) .repartition(20) val optimized1 = Optimize.execute(query1.analyze) @@ -160,17 +160,17 @@ class CollapseRepartitionSuite extends PlanTest { test("coalesce above distribute") { // Remove useless coalesce above distribute val query1 = testRelation - .distribute('a)(10) + .distribute(Symbol("a"))(10) .coalesce(20) val optimized1 = Optimize.execute(query1.analyze) - val correctAnswer1 = testRelation.distribute('a)(10).analyze + val correctAnswer1 = testRelation.distribute(Symbol("a"))(10).analyze comparePlans(optimized1, 
correctAnswer1) // No change in this case val query2 = testRelation - .distribute('a)(30) + .distribute(Symbol("a"))(30) .coalesce(20) val optimized2 = Optimize.execute(query2.analyze) @@ -182,15 +182,15 @@ class CollapseRepartitionSuite extends PlanTest { test("collapse two adjacent distributes into one") { // Always respects the top distribute val query1 = testRelation - .distribute('b)(10) - .distribute('a)(20) + .distribute(Symbol("b"))(10) + .distribute(Symbol("a"))(20) val query2 = testRelation - .distribute('b)(30) - .distribute('a)(20) + .distribute(Symbol("b"))(30) + .distribute(Symbol("a"))(20) val optimized1 = Optimize.execute(query1.analyze) val optimized2 = Optimize.execute(query2.analyze) - val correctAnswer = testRelation.distribute('a)(20).analyze + val correctAnswer = testRelation.distribute(Symbol("a"))(20).analyze comparePlans(optimized1, correctAnswer) comparePlans(optimized2, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala index 3b3b4907eea8..53685c25f274 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseWindowSuite.scala @@ -30,7 +30,7 @@ class CollapseWindowSuite extends PlanTest { CollapseWindow) :: Nil } - val testRelation = LocalRelation('a.double, 'b.double, 'c.string) + val testRelation = LocalRelation(Symbol("a").double, Symbol("b").double, Symbol("c").string) val a = testRelation.output(0) val b = testRelation.output(1) val c = testRelation.output(2) @@ -41,28 +41,28 @@ class CollapseWindowSuite extends PlanTest { test("collapse two adjacent windows with the same partition/order") { val query = testRelation - .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1) - .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1) - .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1) - .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1) + .window(Seq(min(a).as(Symbol("min_a"))), partitionSpec1, orderSpec1) + .window(Seq(max(a).as(Symbol("max_a"))), partitionSpec1, orderSpec1) + .window(Seq(sum(b).as(Symbol("sum_b"))), partitionSpec1, orderSpec1) + .window(Seq(avg(b).as(Symbol("avg_b"))), partitionSpec1, orderSpec1) val analyzed = query.analyze val optimized = Optimize.execute(analyzed) assert(analyzed.output === optimized.output) val correctAnswer = testRelation.window(Seq( - min(a).as('min_a), - max(a).as('max_a), - sum(b).as('sum_b), - avg(b).as('avg_b)), partitionSpec1, orderSpec1) + min(a).as(Symbol("min_a")), + max(a).as(Symbol("max_a")), + sum(b).as(Symbol("sum_b")), + avg(b).as(Symbol("avg_b"))), partitionSpec1, orderSpec1) comparePlans(optimized, correctAnswer) } test("Don't collapse adjacent windows with different partitions or orders") { val query1 = testRelation - .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1) - .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2) + .window(Seq(min(a).as(Symbol("min_a"))), partitionSpec1, orderSpec1) + .window(Seq(max(a).as(Symbol("max_a"))), partitionSpec1, orderSpec2) val optimized1 = Optimize.execute(query1.analyze) val correctAnswer1 = query1.analyze @@ -70,8 +70,8 @@ class CollapseWindowSuite extends PlanTest { comparePlans(optimized1, correctAnswer1) val query2 = testRelation - .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1) - .window(Seq(max(a).as('max_a)), 
partitionSpec2, orderSpec1) + .window(Seq(min(a).as(Symbol("min_a"))), partitionSpec1, orderSpec1) + .window(Seq(max(a).as(Symbol("max_a"))), partitionSpec2, orderSpec1) val optimized2 = Optimize.execute(query2.analyze) val correctAnswer2 = query2.analyze @@ -81,8 +81,8 @@ class CollapseWindowSuite extends PlanTest { test("Don't collapse adjacent windows with dependent columns") { val query = testRelation - .window(Seq(sum(a).as('sum_a)), partitionSpec1, orderSpec1) - .window(Seq(max('sum_a).as('max_sum_a)), partitionSpec1, orderSpec1) + .window(Seq(sum(a).as(Symbol("sum_a"))), partitionSpec1, orderSpec1) + .window(Seq(max(Symbol("sum_a")).as(Symbol("max_sum_a"))), partitionSpec1, orderSpec1) .analyze val expected = query.analyze @@ -93,7 +93,7 @@ class CollapseWindowSuite extends PlanTest { test("Skip windows with empty window expressions") { val query = testRelation .window(Seq(), partitionSpec1, orderSpec1) - .window(Seq(sum(a).as('sum_a)), partitionSpec1, orderSpec1) + .window(Seq(sum(a).as(Symbol("sum_a"))), partitionSpec1, orderSpec1) val optimized = Optimize.execute(query.analyze) val correctAnswer = query.analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index bfa415afeab9..775e719947bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -41,61 +41,63 @@ class ColumnPruningSuite extends PlanTest { } test("Column pruning for Generate when Generate.unrequiredChildIndex = child.output") { - val input = LocalRelation('a.int, 'b.int, 'c.array(StringType)) + val input = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").array(StringType)) val query = input - .generate(Explode('c), outputNames = "explode" :: Nil) - .select('c, 'explode) + .generate(Explode(Symbol("c")), outputNames = "explode" :: Nil) + .select(Symbol("c"), Symbol("explode")) .analyze val optimized = Optimize.execute(query) val correctAnswer = input - .select('c) - .generate(Explode('c), outputNames = "explode" :: Nil) + .select(Symbol("c")) + .generate(Explode(Symbol("c")), outputNames = "explode" :: Nil) .analyze comparePlans(optimized, correctAnswer) } test("Fill Generate.unrequiredChildIndex if possible") { - val input = LocalRelation('b.array(StringType)) + val input = LocalRelation(Symbol("b").array(StringType)) val query = input - .generate(Explode('b), outputNames = "explode" :: Nil) - .select(('explode + 1).as("result")) + .generate(Explode(Symbol("b")), outputNames = "explode" :: Nil) + .select((Symbol("explode") + 1).as("result")) .analyze val optimized = Optimize.execute(query) val correctAnswer = input - .generate(Explode('b), unrequiredChildIndex = input.output.zipWithIndex.map(_._2), + .generate(Explode(Symbol("b")), unrequiredChildIndex = input.output.zipWithIndex.map(_._2), outputNames = "explode" :: Nil) - .select(('explode + 1).as("result")) + .select((Symbol("explode") + 1).as("result")) .analyze comparePlans(optimized, correctAnswer) } test("Another fill Generate.unrequiredChildIndex if possible") { - val input = LocalRelation('a.int, 'b.int, 'c1.string, 'c2.string) + val input = + LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c1").string, Symbol("c2").string) val query = input - .generate(Explode(CreateArray(Seq('c1, 'c2))), outputNames = "explode" :: Nil) - .select('a, 
'c1, 'explode) + .generate( + Explode(CreateArray(Seq(Symbol("c1"), Symbol("c2")))), outputNames = "explode" :: Nil) + .select(Symbol("a"), Symbol("c1"), Symbol("explode")) .analyze val optimized = Optimize.execute(query) val correctAnswer = input - .select('a, 'c1, 'c2) - .generate(Explode(CreateArray(Seq('c1, 'c2))), + .select(Symbol("a"), Symbol("c1"), Symbol("c2")) + .generate(Explode(CreateArray(Seq(Symbol("c1"), Symbol("c2")))), unrequiredChildIndex = Seq(2), outputNames = "explode" :: Nil) .analyze @@ -113,10 +115,10 @@ class ColumnPruningSuite extends PlanTest { withSQLConf(SQLConf.NESTED_PRUNING_ON_EXPRESSIONS.key -> "true") { val structType = StructType.fromDDL("d double, e array<string>, f double, g double, " + "h array<struct<h1: int, h2: double>>") - val input = LocalRelation('a.int, 'b.int, 'c.struct(structType)) + val input = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").struct(structType)) val generatorOutputs = generatorOutputNames.map(UnresolvedAttribute(_)) - val selectedExprs = Seq(UnresolvedAttribute("a"), 'c.getField("d")) ++ + val selectedExprs = Seq(UnresolvedAttribute("a"), Symbol("c").getField("d")) ++ generatorOutputs val query = @@ -147,82 +149,86 @@ class ColumnPruningSuite extends PlanTest { } runTest( - Explode('c.getField("e")), + Explode(Symbol("c").getField("e")), aliases => Explode($"${aliases(1)}".as("c.e")), - aliases => Seq('c.getField("d").as(aliases(0)), 'c.getField("e").as(aliases(1))), + aliases => + Seq(Symbol("c").getField("d").as(aliases(0)), Symbol("c").getField("e").as(aliases(1))), Seq(2), Seq("explode") ) - runTest(Stack(2 :: 'c.getField("f") :: 'c.getField("g") :: Nil), + runTest(Stack(2 :: Symbol("c").getField("f") :: Symbol("c").getField("g") :: Nil), aliases => Stack(2 :: $"${aliases(1)}".as("c.f") :: $"${aliases(2)}".as("c.g") :: Nil), aliases => Seq( - 'c.getField("d").as(aliases(0)), - 'c.getField("f").as(aliases(1)), - 'c.getField("g").as(aliases(2))), + Symbol("c").getField("d").as(aliases(0)), + Symbol("c").getField("f").as(aliases(1)), + Symbol("c").getField("g").as(aliases(2))), Seq(2, 3), Seq("stack") ) runTest( - PosExplode('c.getField("e")), + PosExplode(Symbol("c").getField("e")), aliases => PosExplode($"${aliases(1)}".as("c.e")), - aliases => Seq('c.getField("d").as(aliases(0)), 'c.getField("e").as(aliases(1))), + aliases => + Seq(Symbol("c").getField("d").as(aliases(0)), Symbol("c").getField("e").as(aliases(1))), Seq(2), Seq("pos", "explode") ) runTest( - Inline('c.getField("h")), + Inline(Symbol("c").getField("h")), aliases => Inline($"${aliases(1)}".as("c.h")), - aliases => Seq('c.getField("d").as(aliases(0)), 'c.getField("h").as(aliases(1))), + aliases => + Seq(Symbol("c").getField("d").as(aliases(0)), Symbol("c").getField("h").as(aliases(1))), Seq(2), Seq("h1", "h2") ) } test("Column pruning for Project on Sort") { - val input = LocalRelation('a.int, 'b.string, 'c.double) - val query = input.orderBy('b.asc).select('a).analyze + val input = LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").double) + val query = input.orderBy(Symbol("b").asc).select(Symbol("a")).analyze val optimized = Optimize.execute(query) - val correctAnswer = input.select('a, 'b).orderBy('b.asc).select('a).analyze + val correctAnswer = input.select( + Symbol("a"), Symbol("b")).orderBy(Symbol("b").asc).select(Symbol("a")).analyze comparePlans(optimized, correctAnswer) } test("Column pruning for Expand") { - val input = LocalRelation('a.int, 'b.string, 'c.double) + val input = LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").double) val query = 
Aggregate( - Seq('aa, 'gid), - Seq(sum('c).as("sum")), + Seq(Symbol("aa"), Symbol("gid")), + Seq(sum(Symbol("c")).as("sum")), Expand( Seq( - Seq('a, 'b, 'c, Literal.create(null, StringType), 1), - Seq('a, 'b, 'c, 'a, 2)), - Seq('a, 'b, 'c, 'aa.int, 'gid.int), + Seq(Symbol("a"), Symbol("b"), Symbol("c"), Literal.create(null, StringType), 1), + Seq(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("a"), 2)), + Seq(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("aa").int, Symbol("gid").int), input)).analyze val optimized = Optimize.execute(query) val expected = Aggregate( - Seq('aa, 'gid), - Seq(sum('c).as("sum")), + Seq(Symbol("aa"), Symbol("gid")), + Seq(sum(Symbol("c")).as("sum")), Expand( Seq( - Seq('c, Literal.create(null, StringType), 1), - Seq('c, 'a, 2)), - Seq('c, 'aa.int, 'gid.int), - Project(Seq('a, 'c), + Seq(Symbol("c"), Literal.create(null, StringType), 1), + Seq(Symbol("c"), Symbol("a"), 2)), + Seq(Symbol("c"), Symbol("aa").int, Symbol("gid").int), + Project(Seq(Symbol("a"), Symbol("c")), input))).analyze comparePlans(optimized, expected) } test("Column pruning for ScriptTransformation") { - val input = LocalRelation('a.int, 'b.string, 'c.double) + val input = LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").double) val query = ScriptTransformation( - Seq('a, 'b), + Seq(Symbol("a"), Symbol("b")), "func", Seq.empty, input, @@ -231,11 +237,11 @@ class ColumnPruningSuite extends PlanTest { val expected = ScriptTransformation( - Seq('a, 'b), + Seq(Symbol("a"), Symbol("b")), "func", Seq.empty, Project( - Seq('a, 'b), + Seq(Symbol("a"), Symbol("b")), input), null).analyze @@ -243,34 +249,35 @@ class ColumnPruningSuite extends PlanTest { } test("Column pruning on Filter") { - val input = LocalRelation('a.int, 'b.string, 'c.double) - val plan1 = Filter('a > 1, input).analyze + val input = LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").double) + val plan1 = Filter(Symbol("a") > 1, input).analyze comparePlans(Optimize.execute(plan1), plan1) - val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze + val query = Project(Symbol("a") :: Nil, Filter(Symbol("c") > Literal(0.0), input)).analyze comparePlans(Optimize.execute(query), query) - val plan2 = Filter('b > 1, Project(Seq('a, 'b), input)).analyze - val expected2 = Project(Seq('a, 'b), Filter('b > 1, input)).analyze + val plan2 = Filter(Symbol("b") > 1, Project(Seq(Symbol("a"), Symbol("b")), input)).analyze + val expected2 = Project(Seq(Symbol("a"), Symbol("b")), Filter(Symbol("b") > 1, input)).analyze comparePlans(Optimize.execute(plan2), expected2) - val plan3 = Project(Seq('a), Filter('b > 1, Project(Seq('a, 'b), input))).analyze - val expected3 = Project(Seq('a), Filter('b > 1, input)).analyze + val plan3 = Project(Seq(Symbol("a")), Filter(Symbol("b") > 1, + Project(Seq(Symbol("a"), Symbol("b")), input))).analyze + val expected3 = Project(Seq(Symbol("a")), Filter(Symbol("b") > 1, input)).analyze comparePlans(Optimize.execute(plan3), expected3) } test("Column pruning on except/intersect/distinct") { - val input = LocalRelation('a.int, 'b.string, 'c.double) - val query = Project('a :: Nil, Except(input, input, isAll = false)).analyze + val input = LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").double) + val query = Project(Symbol("a") :: Nil, Except(input, input, isAll = false)).analyze comparePlans(Optimize.execute(query), query) - val query2 = Project('a :: Nil, Intersect(input, input, isAll = false)).analyze + val query2 = Project(Symbol("a") :: Nil, Intersect(input, input, 
isAll = false)).analyze comparePlans(Optimize.execute(query2), query2) - val query3 = Project('a :: Nil, Distinct(input)).analyze + val query3 = Project(Symbol("a") :: Nil, Distinct(input)).analyze comparePlans(Optimize.execute(query3), query3) } test("Column pruning on Project") { - val input = LocalRelation('a.int, 'b.string, 'c.double) - val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze - val expected = Project(Seq('a), input).analyze + val input = LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").double) + val query = Project(Symbol("a") :: Nil, Project(Seq(Symbol("a"), Symbol("b")), input)).analyze + val expected = Project(Seq(Symbol("a")), input).analyze comparePlans(Optimize.execute(query), expected) } @@ -291,140 +298,151 @@ class ColumnPruningSuite extends PlanTest { } test("column pruning for group") { - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) val originalQuery = testRelation - .groupBy('a)('a, count('b)) - .select('a) + .groupBy(Symbol("a"))(Symbol("a"), count(Symbol("b"))) + .select(Symbol("a")) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a) - .groupBy('a)('a).analyze + .select(Symbol("a")) + .groupBy(Symbol("a"))(Symbol("a")).analyze comparePlans(optimized, correctAnswer) } test("column pruning for group with alias") { - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) val originalQuery = testRelation - .groupBy('a)('a as 'c, count('b)) - .select('c) + .groupBy(Symbol("a"))(Symbol("a") as Symbol("c"), count(Symbol("b"))) + .select(Symbol("c")) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a) - .groupBy('a)('a as 'c).analyze + .select(Symbol("a")) + .groupBy(Symbol("a"))(Symbol("a") as Symbol("c")).analyze comparePlans(optimized, correctAnswer) } test("column pruning for Project(ne, Limit)") { - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) val originalQuery = testRelation - .select('a, 'b) + .select(Symbol("a"), Symbol("b")) .limit(2) - .select('a) + .select(Symbol("a")) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a) + .select(Symbol("a")) .limit(2).analyze comparePlans(optimized, correctAnswer) } test("push down project past sort") { - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val x = testRelation.subquery('x) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val x = testRelation.subquery(Symbol("x")) // push down valid val originalQuery = { - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('a) + x.select(Symbol("a"), Symbol("b")) + .sortBy(SortOrder(Symbol("a"), Ascending)) + .select(Symbol("a")) } val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = - x.select('a) - .sortBy(SortOrder('a, Ascending)).analyze + x.select(Symbol("a")) + .sortBy(SortOrder(Symbol("a"), Ascending)).analyze comparePlans(optimized, correctAnswer) // push down invalid val originalQuery1 = { - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b) + x.select(Symbol("a"), Symbol("b")) + .sortBy(SortOrder(Symbol("a"), Ascending)) + .select(Symbol("b")) } val optimized1 = Optimize.execute(originalQuery1.analyze) val 
correctAnswer1 = - x.select('a, 'b) - .sortBy(SortOrder('a, Ascending)) - .select('b).analyze + x.select(Symbol("a"), Symbol("b")) + .sortBy(SortOrder(Symbol("a"), Ascending)) + .select(Symbol("b")).analyze comparePlans(optimized1, correctAnswer1) } test("Column pruning on Window with useless aggregate functions") { - val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) - val winSpec = windowSpec('a :: Nil, 'd.asc :: Nil, UnspecifiedFrame) - val winExpr = windowExpr(count('d), winSpec) - - val originalQuery = input.groupBy('a, 'c, 'd)('a, 'c, 'd, winExpr.as('window)).select('a, 'c) - val correctAnswer = input.select('a, 'c, 'd).groupBy('a, 'c, 'd)('a, 'c).analyze + val input = + LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").double, Symbol("d").int) + val winSpec = windowSpec(Symbol("a") :: Nil, Symbol("d").asc :: Nil, UnspecifiedFrame) + val winExpr = windowExpr(count(Symbol("d")), winSpec) + + val originalQuery = input.groupBy(Symbol("a"), Symbol("c"), Symbol("d"))( + Symbol("a"), Symbol("c"), Symbol("d"), + winExpr.as(Symbol("window"))).select(Symbol("a"), Symbol("c")) + val correctAnswer = input.select( + Symbol("a"), Symbol("c"), Symbol("d")).groupBy(Symbol("a"), + Symbol("c"), Symbol("d"))(Symbol("a"), Symbol("c")).analyze val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, correctAnswer) } test("Column pruning on Window with selected agg expressions") { - val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) - val winSpec = windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame) - val winExpr = windowExpr(count('b), winSpec) + val input = + LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").double, Symbol("d").int) + val winSpec = windowSpec(Symbol("a") :: Nil, Symbol("b").asc :: Nil, UnspecifiedFrame) + val winExpr = windowExpr(count(Symbol("b")), winSpec) val originalQuery = - input.select('a, 'b, 'c, 'd, winExpr.as('window)).where('window > 1).select('a, 'c) + input.select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), + winExpr.as(Symbol("window"))).where(Symbol("window") > 1).select(Symbol("a"), Symbol("c")) val correctAnswer = - input.select('a, 'b, 'c) - .window(winExpr.as('window) :: Nil, 'a :: Nil, 'b.asc :: Nil) - .where('window > 1).select('a, 'c).analyze + input.select(Symbol("a"), Symbol("b"), Symbol("c")) + .window(winExpr.as(Symbol("window")) :: Nil, Symbol("a") :: Nil, Symbol("b").asc :: Nil) + .where(Symbol("window") > 1).select(Symbol("a"), Symbol("c")).analyze val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, correctAnswer) } test("Column pruning on Window in select") { - val input = LocalRelation('a.int, 'b.string, 'c.double, 'd.int) - val winSpec = windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame) - val winExpr = windowExpr(count('b), winSpec) - - val originalQuery = input.select('a, 'b, 'c, 'd, winExpr.as('window)).select('a, 'c) - val correctAnswer = input.select('a, 'c).analyze + val input = + LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").double, Symbol("d").int) + val winSpec = windowSpec(Symbol("a") :: Nil, Symbol("b").asc :: Nil, UnspecifiedFrame) + val winExpr = windowExpr(count(Symbol("b")), winSpec) + + val originalQuery = input.select( + Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), + winExpr.as(Symbol("window"))).select(Symbol("a"), Symbol("c")) + val correctAnswer = input.select(Symbol("a"), Symbol("c")).analyze val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, correctAnswer) } 
test("Column pruning on Union") { - val input1 = LocalRelation('a.int, 'b.string, 'c.double) - val input2 = LocalRelation('c.int, 'd.string, 'e.double) - val query = Project('b :: Nil, Union(input1 :: input2 :: Nil)).analyze - val expected = Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil).analyze + val input1 = LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").double) + val input2 = LocalRelation(Symbol("c").int, Symbol("d").string, Symbol("e").double) + val query = Project(Symbol("b") :: Nil, Union(input1 :: input2 :: Nil)).analyze + val expected = Union( + Project(Symbol("b") :: Nil, input1) :: Project(Symbol("d") :: Nil, input2) :: Nil).analyze comparePlans(Optimize.execute(query), expected) } test("Remove redundant projects in column pruning rule") { - val input = LocalRelation('key.int, 'value.string) + val input = LocalRelation(Symbol("key").int, Symbol("value").string) val query = Project(Seq($"x.key", $"y.key"), @@ -447,33 +465,34 @@ class ColumnPruningSuite extends PlanTest { private val func = identity[Iterator[OtherTuple]] _ test("Column pruning on MapPartitions") { - val input = LocalRelation('_1.int, '_2.int, 'c.int) + val input = LocalRelation(Symbol("_1").int, Symbol("_2").int, Symbol("c").int) val plan1 = MapPartitions(func, input) val correctAnswer1 = - MapPartitions(func, Project(Seq('_1, '_2), input)).analyze + MapPartitions(func, Project(Seq(Symbol("_1"), Symbol("_2")), input)).analyze comparePlans(Optimize.execute(plan1.analyze), correctAnswer1) } test("push project down into sample") { - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val x = testRelation.subquery('x) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val x = testRelation.subquery(Symbol("x")) - val query1 = Sample(0.0, 0.6, false, 11L, x).select('a) + val query1 = Sample(0.0, 0.6, false, 11L, x).select(Symbol("a")) val optimized1 = Optimize.execute(query1.analyze) - val expected1 = Sample(0.0, 0.6, false, 11L, x.select('a)) + val expected1 = Sample(0.0, 0.6, false, 11L, x.select(Symbol("a"))) comparePlans(optimized1, expected1.analyze) - val query2 = Sample(0.0, 0.6, false, 11L, x).select('a as 'aa) + val query2 = Sample(0.0, 0.6, false, 11L, x).select(Symbol("a") as Symbol("aa")) val optimized2 = Optimize.execute(query2.analyze) - val expected2 = Sample(0.0, 0.6, false, 11L, x.select('a as 'aa)) + val expected2 = Sample(0.0, 0.6, false, 11L, x.select(Symbol("a") as Symbol("aa"))) comparePlans(optimized2, expected2.analyze) } test("SPARK-24696 ColumnPruning rule fails to remove extra Project") { - val input = LocalRelation('key.int, 'value.string) - val query = input.select('key).where(rand(0L) > 0.5).where('key < 10).analyze + val input = LocalRelation(Symbol("key").int, Symbol("value").string) + val query = input.select(Symbol("key")).where(rand(0L) > 0.5).where(Symbol("key") < 10).analyze val optimized = Optimize.execute(query) - val expected = input.where(rand(0L) > 0.5).where('key < 10).select('key).analyze + val expected = input.where( + rand(0L) > 0.5).where(Symbol("key") < 10).select(Symbol("key")).analyze comparePlans(optimized, expected) } // todo: add more tests for column pruning diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 11f908ac180b..f41a61877cf6 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -55,14 +55,14 @@ class CombiningLimitsSuite extends PlanTest { test("limits: combines two limits") { val originalQuery = testRelation - .select('a) + .select(Symbol("a")) .limit(10) .limit(5) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a) + .select(Symbol("a")) .limit(5).analyze comparePlans(optimized, correctAnswer) @@ -71,7 +71,7 @@ class CombiningLimitsSuite extends PlanTest { test("limits: combines three limits") { val originalQuery = testRelation - .select('a) + .select(Symbol("a")) .limit(2) .limit(7) .limit(5) @@ -79,7 +79,7 @@ class CombiningLimitsSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a) + .select(Symbol("a")) .limit(2).analyze comparePlans(optimized, correctAnswer) @@ -88,15 +88,15 @@ class CombiningLimitsSuite extends PlanTest { test("limits: combines two limits after ColumnPruning") { val originalQuery = testRelation - .select('a) + .select(Symbol("a")) .limit(2) - .select('a) + .select(Symbol("a")) .limit(5) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a) + .select(Symbol("a")) .limit(2).analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index ae644c111074..5a80fcd8fcc8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -38,18 +38,18 @@ class ConstantFoldingSuite extends PlanTest { BooleanSimplification) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) test("eliminate subqueries") { val originalQuery = testRelation - .subquery('y) - .select('a) + .subquery(Symbol("y")) + .select(Symbol("a")) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a.attr) + .select(Symbol("a").attr) .analyze comparePlans(optimized, correctAnswer) @@ -93,20 +93,20 @@ class ConstantFoldingSuite extends PlanTest { val originalQuery = testRelation .select( - Literal(2) + Literal(3) + 'a as Symbol("c1"), - 'a + Literal(2) + Literal(3) as Symbol("c2"), - Literal(2) * 'a + Literal(4) as Symbol("c3"), - 'a * (Literal(3) + Literal(4)) as Symbol("c4")) + Literal(2) + Literal(3) + Symbol("a") as Symbol("c1"), + Symbol("a") + Literal(2) + Literal(3) as Symbol("c2"), + Literal(2) * Symbol("a") + Literal(4) as Symbol("c3"), + Symbol("a") * (Literal(3) + Literal(4)) as Symbol("c4")) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select( - Literal(5) + 'a as Symbol("c1"), - 'a + Literal(2) + Literal(3) as Symbol("c2"), - Literal(2) * 'a + Literal(4) as Symbol("c3"), - 'a * Literal(7) as Symbol("c4")) + Literal(5) + Symbol("a") as Symbol("c1"), + Symbol("a") + Literal(2) + Literal(3) as Symbol("c2"), + Literal(2) * Symbol("a") + Literal(4) as Symbol("c3"), + Symbol("a") * Literal(7) as Symbol("c4")) .analyze comparePlans(optimized, correctAnswer) @@ -117,20 +117,20 @@ class 
ConstantFoldingSuite extends PlanTest { val originalQuery = testRelation .where( - (('a > 1 && Literal(1) === Literal(1)) || - ('a < 10 && Literal(1) === Literal(2)) || - (Literal(1) === Literal(1) && 'b > 1) || - (Literal(1) === Literal(2) && 'b < 10)) && - (('a > 1 || Literal(1) === Literal(1)) && - ('a < 10 || Literal(1) === Literal(2)) && - (Literal(1) === Literal(1) || 'b > 1) && - (Literal(1) === Literal(2) || 'b < 10))) + ((Symbol("a") > 1 && Literal(1) === Literal(1)) || + (Symbol("a") < 10 && Literal(1) === Literal(2)) || + (Literal(1) === Literal(1) && Symbol("b") > 1) || + (Literal(1) === Literal(2) && Symbol("b") < 10)) && + ((Symbol("a") > 1 || Literal(1) === Literal(1)) && + (Symbol("a") < 10 || Literal(1) === Literal(2)) && + (Literal(1) === Literal(1) || Symbol("b") > 1) && + (Literal(1) === Literal(2) || Symbol("b") < 10))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(('a > 1 || 'b > 1) && ('a < 10 && 'b < 10)) + .where((Symbol("a") > 1 || Symbol("b") > 1) && (Symbol("a") < 10 && Symbol("b") < 10)) .analyze comparePlans(optimized, correctAnswer) @@ -140,7 +140,7 @@ class ConstantFoldingSuite extends PlanTest { val originalQuery = testRelation .select( - Cast(Literal("2"), IntegerType) + Literal(3) + 'a as Symbol("c1"), + Cast(Literal("2"), IntegerType) + Literal(3) + Symbol("a") as Symbol("c1"), Coalesce(Seq(Cast(Literal("abc"), IntegerType), Literal(3))) as Symbol("c2")) val optimized = Optimize.execute(originalQuery.analyze) @@ -148,7 +148,7 @@ class ConstantFoldingSuite extends PlanTest { val correctAnswer = testRelation .select( - Literal(5) + 'a as Symbol("c1"), + Literal(5) + Symbol("a") as Symbol("c1"), Literal(3) as Symbol("c2")) .analyze @@ -160,7 +160,7 @@ class ConstantFoldingSuite extends PlanTest { testRelation .select( Rand(5L) + Literal(1) as Symbol("c1"), - sum('a) as Symbol("c2")) + sum(Symbol("a")) as Symbol("c2")) val optimized = Optimize.execute(originalQuery.analyze) @@ -168,7 +168,7 @@ class ConstantFoldingSuite extends PlanTest { testRelation .select( Rand(5L) + Literal(1.0) as Symbol("c1"), - sum('a) as Symbol("c2")) + sum(Symbol("a")) as Symbol("c2")) .analyze comparePlans(optimized, correctAnswer) @@ -176,37 +176,38 @@ class ConstantFoldingSuite extends PlanTest { test("Constant folding test: expressions have null literals") { val originalQuery = testRelation.select( - IsNull(Literal(null)) as 'c1, - IsNotNull(Literal(null)) as 'c2, + IsNull(Literal(null)) as Symbol("c1"), + IsNotNull(Literal(null)) as Symbol("c2"), - UnresolvedExtractValue(Literal.create(null, ArrayType(IntegerType)), 1) as 'c3, + UnresolvedExtractValue(Literal.create(null, ArrayType(IntegerType)), 1) as Symbol("c3"), UnresolvedExtractValue( - Literal.create(Seq(1), ArrayType(IntegerType)), Literal.create(null, IntegerType)) as 'c4, + Literal.create(Seq(1), ArrayType(IntegerType)), + Literal.create(null, IntegerType)) as Symbol("c4"), UnresolvedExtractValue( Literal.create(null, StructType(Seq(StructField("a", IntegerType, true)))), - "a") as 'c5, + "a") as Symbol("c5"), - UnaryMinus(Literal.create(null, IntegerType)) as 'c6, - Cast(Literal(null), IntegerType) as 'c7, - Not(Literal.create(null, BooleanType)) as 'c8, + UnaryMinus(Literal.create(null, IntegerType)) as Symbol("c6"), + Cast(Literal(null), IntegerType) as Symbol("c7"), + Not(Literal.create(null, BooleanType)) as Symbol("c8"), - Add(Literal.create(null, IntegerType), 1) as 'c9, - Add(1, Literal.create(null, IntegerType)) as 'c10, + Add(Literal.create(null, 
IntegerType), 1) as Symbol("c9"), + Add(1, Literal.create(null, IntegerType)) as Symbol("c10"), - EqualTo(Literal.create(null, IntegerType), 1) as 'c11, - EqualTo(1, Literal.create(null, IntegerType)) as 'c12, + EqualTo(Literal.create(null, IntegerType), 1) as Symbol("c11"), + EqualTo(1, Literal.create(null, IntegerType)) as Symbol("c12"), - new Like(Literal.create(null, StringType), "abc") as 'c13, - new Like("abc", Literal.create(null, StringType)) as 'c14, + new Like(Literal.create(null, StringType), "abc") as Symbol("c13"), + new Like("abc", Literal.create(null, StringType)) as Symbol("c14"), - Upper(Literal.create(null, StringType)) as 'c15, + Upper(Literal.create(null, StringType)) as Symbol("c15"), - Substring(Literal.create(null, StringType), 0, 1) as 'c16, - Substring("abc", Literal.create(null, IntegerType), 1) as 'c17, - Substring("abc", 0, Literal.create(null, IntegerType)) as 'c18, + Substring(Literal.create(null, StringType), 0, 1) as Symbol("c16"), + Substring("abc", Literal.create(null, IntegerType), 1) as Symbol("c17"), + Substring("abc", 0, Literal.create(null, IntegerType)) as Symbol("c18"), - Contains(Literal.create(null, StringType), "abc") as 'c19, - Contains("abc", Literal.create(null, StringType)) as 'c20 + Contains(Literal.create(null, StringType), "abc") as Symbol("c19"), + Contains("abc", Literal.create(null, StringType)) as Symbol("c20") ) val optimized = Optimize.execute(originalQuery.analyze) @@ -214,34 +215,34 @@ class ConstantFoldingSuite extends PlanTest { val correctAnswer = testRelation .select( - Literal(true) as 'c1, - Literal(false) as 'c2, + Literal(true) as Symbol("c1"), + Literal(false) as Symbol("c2"), - Literal.create(null, IntegerType) as 'c3, - Literal.create(null, IntegerType) as 'c4, - Literal.create(null, IntegerType) as 'c5, + Literal.create(null, IntegerType) as Symbol("c3"), + Literal.create(null, IntegerType) as Symbol("c4"), + Literal.create(null, IntegerType) as Symbol("c5"), - Literal.create(null, IntegerType) as 'c6, - Literal.create(null, IntegerType) as 'c7, - Literal.create(null, BooleanType) as 'c8, + Literal.create(null, IntegerType) as Symbol("c6"), + Literal.create(null, IntegerType) as Symbol("c7"), + Literal.create(null, BooleanType) as Symbol("c8"), - Literal.create(null, IntegerType) as 'c9, - Literal.create(null, IntegerType) as 'c10, + Literal.create(null, IntegerType) as Symbol("c9"), + Literal.create(null, IntegerType) as Symbol("c10"), - Literal.create(null, BooleanType) as 'c11, - Literal.create(null, BooleanType) as 'c12, + Literal.create(null, BooleanType) as Symbol("c11"), + Literal.create(null, BooleanType) as Symbol("c12"), - Literal.create(null, BooleanType) as 'c13, - Literal.create(null, BooleanType) as 'c14, + Literal.create(null, BooleanType) as Symbol("c13"), + Literal.create(null, BooleanType) as Symbol("c14"), - Literal.create(null, StringType) as 'c15, + Literal.create(null, StringType) as Symbol("c15"), - Literal.create(null, StringType) as 'c16, - Literal.create(null, StringType) as 'c17, - Literal.create(null, StringType) as 'c18, + Literal.create(null, StringType) as Symbol("c16"), + Literal.create(null, StringType) as Symbol("c17"), + Literal.create(null, StringType) as Symbol("c18"), - Literal.create(null, BooleanType) as 'c19, - Literal.create(null, BooleanType) as 'c20 + Literal.create(null, BooleanType) as Symbol("c19"), + Literal.create(null, BooleanType) as Symbol("c20") ).analyze comparePlans(optimized, correctAnswer) @@ -250,14 +251,14 @@ class ConstantFoldingSuite extends PlanTest { 
test("Constant folding test: Fold In(v, list) into true or false") { val originalQuery = testRelation - .select('a) + .select(Symbol("a")) .where(In(Literal(1), Seq(Literal(1), Literal(2)))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a) + .select(Symbol("a")) .where(Literal(true)) .analyze @@ -267,7 +268,7 @@ class ConstantFoldingSuite extends PlanTest { test("SPARK-33544: Constant folding test with side effects") { val originalQuery = testRelation - .select('a) + .select(Symbol("a")) .where(Size(CreateArray(Seq(AssertTrue(false)))) > 0) val optimized = Optimize.execute(originalQuery.analyze) @@ -287,14 +288,14 @@ class ConstantFoldingSuite extends PlanTest { test("SPARK-33544: Constant folding test CreateArray") { val originalQuery = testRelation - .select('a) - .where(Size(CreateArray(Seq('a))) > 0) + .select(Symbol("a")) + .where(Size(CreateArray(Seq(Symbol("a")))) > 0) val optimized = OptimizeForCreate.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a) + .select(Symbol("a")) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala index 171ac4e3091c..17e942f90690 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantPropagationSuite.scala @@ -40,12 +40,13 @@ class ConstantPropagationSuite extends PlanTest { BooleanSimplification) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.int.notNull) + val testRelation = + LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int, Symbol("d").int.notNull) - private val columnA = 'a - private val columnB = 'b - private val columnC = 'c - private val columnD = 'd + private val columnA = Symbol("a") + private val columnB = Symbol("b") + private val columnC = Symbol("c") + private val columnD = Symbol("d") test("basic test") { val query = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala index 43579d4c903a..ec5be0685241 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConvertToLocalRelationSuite.scala @@ -39,11 +39,11 @@ class ConvertToLocalRelationSuite extends PlanTest { test("Project on LocalRelation should be turned into a single LocalRelation") { val testRelation = LocalRelation( - LocalRelation('a.int, 'b.int).output, + LocalRelation(Symbol("a").int, Symbol("b").int).output, InternalRow(1, 2) :: InternalRow(4, 5) :: Nil) val correctAnswer = LocalRelation( - LocalRelation('a1.int, 'b1.int).output, + LocalRelation(Symbol("a1").int, Symbol("b1").int).output, InternalRow(1, 3) :: InternalRow(4, 6) :: Nil) val projectOnLocal = testRelation.select( @@ -57,11 +57,11 @@ class ConvertToLocalRelationSuite extends PlanTest { test("Filter on LocalRelation should be turned into a single LocalRelation") { val testRelation = LocalRelation( - LocalRelation('a.int, 'b.int).output, + LocalRelation(Symbol("a").int, Symbol("b").int).output, InternalRow(1, 2) :: InternalRow(4, 5) :: Nil) val 
correctAnswer = LocalRelation( - LocalRelation('a1.int, 'b1.int).output, + LocalRelation(Symbol("a1").int, Symbol("b1").int).output, InternalRow(1, 3) :: Nil) val filterAndProjectOnLocal = testRelation @@ -75,11 +75,11 @@ class ConvertToLocalRelationSuite extends PlanTest { test("SPARK-27798: Expression reusing output shouldn't override values in local relation") { val testRelation = LocalRelation( - LocalRelation('a.int).output, + LocalRelation(Symbol("a").int).output, InternalRow(1) :: InternalRow(2) :: Nil) val correctAnswer = LocalRelation( - LocalRelation('a.struct('a1.int)).output, + LocalRelation(Symbol("a").struct(Symbol("a1").int)).output, InternalRow(InternalRow(1)) :: InternalRow(InternalRow(2)) :: Nil) val projected = testRelation.select(ExprReuseOutput(UnresolvedAttribute("a")).as("a")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala index 711294ed6192..6736fa3c900e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala @@ -32,19 +32,19 @@ class DecimalAggregatesSuite extends PlanTest { DecimalAggregates) :: Nil } - val testRelation = LocalRelation('a.decimal(2, 1), 'b.decimal(12, 1)) + val testRelation = LocalRelation(Symbol("a").decimal(2, 1), Symbol("b").decimal(12, 1)) test("Decimal Sum Aggregation: Optimized") { - val originalQuery = testRelation.select(sum('a)) + val originalQuery = testRelation.select(sum(Symbol("a"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select(MakeDecimal(sum(UnscaledValue('a)), 12, 1).as("sum(a)")).analyze + .select(MakeDecimal(sum(UnscaledValue(Symbol("a"))), 12, 1).as("sum(a)")).analyze comparePlans(optimized, correctAnswer) } test("Decimal Sum Aggregation: Not Optimized") { - val originalQuery = testRelation.select(sum('b)) + val originalQuery = testRelation.select(sum(Symbol("b"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = originalQuery.analyze @@ -52,16 +52,16 @@ class DecimalAggregatesSuite extends PlanTest { } test("Decimal Average Aggregation: Optimized") { - val originalQuery = testRelation.select(avg('a)) + val originalQuery = testRelation.select(avg(Symbol("a"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select((avg(UnscaledValue('a)) / 10.0).cast(DecimalType(6, 5)).as("avg(a)")).analyze + .select((avg(UnscaledValue(Symbol("a"))) / 10.0).cast(DecimalType(6, 5)).as("avg(a)")).analyze comparePlans(optimized, correctAnswer) } test("Decimal Average Aggregation: Not Optimized") { - val originalQuery = testRelation.select(avg('b)) + val originalQuery = testRelation.select(avg(Symbol("b"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = originalQuery.analyze @@ -69,25 +69,26 @@ class DecimalAggregatesSuite extends PlanTest { } test("Decimal Sum Aggregation over Window: Optimized") { - val spec = windowSpec(Seq('a), Nil, UnspecifiedFrame) - val originalQuery = testRelation.select(windowExpr(sum('a), spec).as('sum_a)) + val spec = windowSpec(Seq(Symbol("a")), Nil, UnspecifiedFrame) + val originalQuery = testRelation.select(windowExpr(sum(Symbol("a")), spec).as(Symbol("sum_a"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - 
.select('a) + .select(Symbol("a")) .window( - Seq(MakeDecimal(windowExpr(sum(UnscaledValue('a)), spec), 12, 1).as('sum_a)), - Seq('a), + Seq(MakeDecimal( + windowExpr(sum(UnscaledValue(Symbol("a"))), spec), 12, 1).as(Symbol("sum_a"))), + Seq(Symbol("a")), Nil) - .select('a, 'sum_a, 'sum_a) - .select('sum_a) + .select(Symbol("a"), Symbol("sum_a"), Symbol("sum_a")) + .select(Symbol("sum_a")) .analyze comparePlans(optimized, correctAnswer) } test("Decimal Sum Aggregation over Window: Not Optimized") { - val spec = windowSpec('b :: Nil, Nil, UnspecifiedFrame) - val originalQuery = testRelation.select(windowExpr(sum('b), spec)) + val spec = windowSpec(Symbol("b") :: Nil, Nil, UnspecifiedFrame) + val originalQuery = testRelation.select(windowExpr(sum(Symbol("b")), spec)) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = originalQuery.analyze @@ -95,25 +96,26 @@ class DecimalAggregatesSuite extends PlanTest { } test("Decimal Average Aggregation over Window: Optimized") { - val spec = windowSpec(Seq('a), Nil, UnspecifiedFrame) - val originalQuery = testRelation.select(windowExpr(avg('a), spec).as('avg_a)) + val spec = windowSpec(Seq(Symbol("a")), Nil, UnspecifiedFrame) + val originalQuery = testRelation.select(windowExpr(avg(Symbol("a")), spec).as(Symbol("avg_a"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a) + .select(Symbol("a")) .window( - Seq((windowExpr(avg(UnscaledValue('a)), spec) / 10.0).cast(DecimalType(6, 5)).as('avg_a)), - Seq('a), + Seq((windowExpr(avg(UnscaledValue(Symbol("a"))), spec) / 10.0) + .cast(DecimalType(6, 5)).as(Symbol("avg_a"))), + Seq(Symbol("a")), Nil) - .select('a, 'avg_a, 'avg_a) - .select('avg_a) + .select(Symbol("a"), Symbol("avg_a"), Symbol("avg_a")) + .select(Symbol("avg_a")) .analyze comparePlans(optimized, correctAnswer) } test("Decimal Average Aggregation over Window: Not Optimized") { - val spec = windowSpec('b :: Nil, Nil, UnspecifiedFrame) - val originalQuery = testRelation.select(windowExpr(avg('b), spec)) + val spec = windowSpec(Symbol("b") :: Nil, Nil, UnspecifiedFrame) + val originalQuery = testRelation.select(windowExpr(avg(Symbol("b")), spec)) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = originalQuery.analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateAggregateFilterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateAggregateFilterSuite.scala index ec9b876f78e1..4bf9163d4ddc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateAggregateFilterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateAggregateFilterSuite.scala @@ -30,44 +30,46 @@ class EliminateAggregateFilterSuite extends PlanTest { Batch("Operator Optimizations", Once, ConstantFolding, EliminateAggregateFilter) :: Nil } - val testRelation = LocalRelation('a.int) + val testRelation = LocalRelation(Symbol("a").int) test("Eliminate Filter always is true") { val query = testRelation - .select(sumDistinct('a, Some(Literal.TrueLiteral)).as('result)) + .select(sumDistinct(Symbol("a"), Some(Literal.TrueLiteral)).as(Symbol("result"))) .analyze val answer = testRelation - .select(sumDistinct('a).as('result)) + .select(sumDistinct(Symbol("a")).as(Symbol("result"))) .analyze comparePlans(Optimize.execute(query), answer) } test("Eliminate Filter is foldable and always is true") { val query = testRelation - 
.select(countDistinctWithFilter(GreaterThan(Literal(2), Literal(1)), 'a).as('result)) + .select(countDistinctWithFilter( + GreaterThan(Literal(2), Literal(1)), Symbol("a")).as(Symbol("result"))) .analyze val answer = testRelation - .select(countDistinct('a).as('result)) + .select(countDistinct(Symbol("a")).as(Symbol("result"))) .analyze comparePlans(Optimize.execute(query), answer) } test("Eliminate Filter always is false") { val query = testRelation - .select(sumDistinct('a, Some(Literal.FalseLiteral)).as('result)) + .select(sumDistinct(Symbol("a"), Some(Literal.FalseLiteral)).as(Symbol("result"))) .analyze val answer = testRelation - .groupBy()(Literal.create(null, LongType).as('result)) + .groupBy()(Literal.create(null, LongType).as(Symbol("result"))) .analyze comparePlans(Optimize.execute(query), answer) } test("Eliminate Filter is foldable and always is false") { val query = testRelation - .select(countDistinctWithFilter(GreaterThan(Literal(1), Literal(2)), 'a).as('result)) + .select(countDistinctWithFilter( + GreaterThan(Literal(1), Literal(2)), Symbol("a")).as(Symbol("result"))) .analyze val answer = testRelation - .groupBy()(Literal.create(0L, LongType).as('result)) + .groupBy()(Literal.create(0L, LongType).as(Symbol("result"))) .analyze comparePlans(Optimize.execute(query), answer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala index 51c751923e41..00ac2c55258d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateDistinctSuite.scala @@ -30,14 +30,14 @@ class EliminateDistinctSuite extends PlanTest { EliminateDistinct) :: Nil } - val testRelation = LocalRelation('a.int) + val testRelation = LocalRelation(Symbol("a").int) test("Eliminate Distinct in Max") { val query = testRelation - .select(maxDistinct('a).as('result)) + .select(maxDistinct(Symbol("a")).as(Symbol("result"))) .analyze val answer = testRelation - .select(max('a).as('result)) + .select(max(Symbol("a")).as(Symbol("result"))) .analyze assert(query != answer) comparePlans(Optimize.execute(query), answer) @@ -45,10 +45,10 @@ class EliminateDistinctSuite extends PlanTest { test("Eliminate Distinct in Min") { val query = testRelation - .select(minDistinct('a).as('result)) + .select(minDistinct(Symbol("a")).as(Symbol("result"))) .analyze val answer = testRelation - .select(min('a).as('result)) + .select(min(Symbol("a")).as(Symbol("result"))) .analyze assert(query != answer) comparePlans(Optimize.execute(query), answer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala index 157472c2293f..001b25c7d7bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala @@ -42,7 +42,7 @@ class EliminateMapObjectsSuite extends PlanTest { test("SPARK-20254: Remove unnecessary data conversion for primitive array") { val intObjType = ObjectType(classOf[Array[Int]]) - val intInput = LocalRelation('a.array(ArrayType(IntegerType, false))) + val intInput = LocalRelation(Symbol("a").array(ArrayType(IntegerType, false))) val intQuery 
= intInput.deserialize[Array[Int]].analyze val intOptimized = Optimize.execute(intQuery) val intExpected = DeserializeToObject( @@ -51,7 +51,7 @@ class EliminateMapObjectsSuite extends PlanTest { comparePlans(intOptimized, intExpected) val doubleObjType = ObjectType(classOf[Array[Double]]) - val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false))) + val doubleInput = LocalRelation(Symbol("a").array(ArrayType(DoubleType, false))) val doubleQuery = doubleInput.deserialize[Array[Double]].analyze val doubleOptimized = Optimize.execute(doubleQuery) val doubleExpected = DeserializeToObject( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala index ef38cc076d95..ef1821d4dea3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala @@ -39,22 +39,22 @@ class EliminateSerializationSuite extends PlanTest { implicit private def intEncoder = ExpressionEncoder[Int]() test("back to back serialization") { - val input = LocalRelation('obj.obj(classOf[(Int, Int)])) + val input = LocalRelation(Symbol("obj").obj(classOf[(Int, Int)])) val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze val optimized = Optimize.execute(plan) - val expected = input.select('obj.as("obj")).analyze + val expected = input.select(Symbol("obj").as("obj")).analyze comparePlans(optimized, expected) } test("back to back serialization with object change") { - val input = LocalRelation('obj.obj(classOf[OtherTuple])) + val input = LocalRelation(Symbol("obj").obj(classOf[OtherTuple])) val plan = input.serialize[OtherTuple].deserialize[(Int, Int)].analyze val optimized = Optimize.execute(plan) comparePlans(optimized, plan) } test("back to back serialization in AppendColumns") { - val input = LocalRelation('obj.obj(classOf[(Int, Int)])) + val input = LocalRelation(Symbol("obj").obj(classOf[(Int, Int)])) val func = (item: (Int, Int)) => item._1 val plan = AppendColumns(func, input.serialize[(Int, Int)]).analyze @@ -70,7 +70,7 @@ class EliminateSerializationSuite extends PlanTest { } test("back to back serialization in AppendColumns with object change") { - val input = LocalRelation('obj.obj(classOf[OtherTuple])) + val input = LocalRelation(Symbol("obj").obj(classOf[OtherTuple])) val func = (item: (Int, Int)) => item._1 val plan = AppendColumns(func, input.serialize[OtherTuple]).analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala index 82db174ad41b..709e897c3e45 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsBeforeRepartitionSuite.scala @@ -30,8 +30,8 @@ class EliminateSortsBeforeRepartitionSuite extends PlanTest { val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry) val analyzer = new Analyzer(catalog) - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val anotherTestRelation = LocalRelation('d.int, 'e.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val 
anotherTestRelation = LocalRelation(Symbol("d").int, Symbol("e").int) object Optimize extends RuleExecutor[LogicalPlan] { val batches = @@ -47,87 +47,100 @@ class EliminateSortsBeforeRepartitionSuite extends PlanTest { def repartition(plan: LogicalPlan): LogicalPlan = plan.repartition(10) test("sortBy") { - val plan = testRelation.select('a, 'b).sortBy('a.asc, 'b.desc) - val optimizedPlan = testRelation.select('a, 'b) + val plan = + testRelation.select(Symbol("a"), Symbol("b")).sortBy(Symbol("a").asc, Symbol("b").desc) + val optimizedPlan = testRelation.select(Symbol("a"), Symbol("b")) checkRepartitionCases(plan, optimizedPlan) } test("sortBy with projection") { - val plan = testRelation.sortBy('a.asc, 'b.asc).select('a + 1 as "a", 'b + 2 as "b") - val optimizedPlan = testRelation.select('a + 1 as "a", 'b + 2 as "b") + val plan = testRelation.sortBy(Symbol("a").asc, Symbol("b").asc) + .select(Symbol("a") + 1 as "a", Symbol("b") + 2 as "b") + val optimizedPlan = testRelation.select(Symbol("a") + 1 as "a", Symbol("b") + 2 as "b") checkRepartitionCases(plan, optimizedPlan) } test("sortBy with projection and filter") { - val plan = testRelation.sortBy('a.asc, 'b.asc).select('a, 'b).where('a === 10) - val optimizedPlan = testRelation.select('a, 'b).where('a === 10) + val plan = testRelation.sortBy(Symbol("a").asc, Symbol("b").asc) + .select(Symbol("a"), Symbol("b")).where(Symbol("a") === 10) + val optimizedPlan = testRelation.select(Symbol("a"), Symbol("b")).where(Symbol("a") === 10) checkRepartitionCases(plan, optimizedPlan) } test("sortBy with limit") { - val plan = testRelation.sortBy('a.asc, 'b.asc).limit(10) - val optimizedPlan = testRelation.sortBy('a.asc, 'b.asc).limit(10) + val plan = testRelation.sortBy(Symbol("a").asc, Symbol("b").asc).limit(10) + val optimizedPlan = testRelation.sortBy(Symbol("a").asc, Symbol("b").asc).limit(10) checkRepartitionCases(plan, optimizedPlan) } test("sortBy with non-deterministic projection") { - val plan = testRelation.sortBy('a.asc, 'b.asc).select(rand(1), 'a, 'b) - val optimizedPlan = testRelation.sortBy('a.asc, 'b.asc).select(rand(1), 'a, 'b) + val plan = testRelation.sortBy(Symbol("a").asc, Symbol("b").asc) + .select(rand(1), Symbol("a"), Symbol("b")) + val optimizedPlan = testRelation.sortBy(Symbol("a").asc, Symbol("b").asc) + .select(rand(1), Symbol("a"), Symbol("b")) checkRepartitionCases(plan, optimizedPlan) } test("orderBy") { - val plan = testRelation.select('a, 'b).orderBy('a.asc, 'b.asc) - val optimizedPlan = testRelation.select('a, 'b) + val plan = + testRelation.select(Symbol("a"), Symbol("b")).orderBy(Symbol("a").asc, Symbol("b").asc) + val optimizedPlan = testRelation.select(Symbol("a"), Symbol("b")) checkRepartitionCases(plan, optimizedPlan) } test("orderBy with projection") { - val plan = testRelation.orderBy('a.asc, 'b.asc).select('a + 1 as "a", 'b + 2 as "b") - val optimizedPlan = testRelation.select('a + 1 as "a", 'b + 2 as "b") + val plan = testRelation.orderBy(Symbol("a").asc, Symbol("b").asc) + .select(Symbol("a") + 1 as "a", Symbol("b") + 2 as "b") + val optimizedPlan = testRelation.select(Symbol("a") + 1 as "a", Symbol("b") + 2 as "b") checkRepartitionCases(plan, optimizedPlan) } test("orderBy with projection and filter") { - val plan = testRelation.orderBy('a.asc, 'b.asc).select('a, 'b).where('a === 10) - val optimizedPlan = testRelation.select('a, 'b).where('a === 10) + val plan = testRelation.orderBy(Symbol("a").asc, Symbol("b").asc) + .select(Symbol("a"), Symbol("b")).where(Symbol("a") === 10) + val optimizedPlan = 
testRelation.select(Symbol("a"), Symbol("b")).where(Symbol("a") === 10) checkRepartitionCases(plan, optimizedPlan) } test("orderBy with limit") { - val plan = testRelation.orderBy('a.asc, 'b.asc).limit(10) - val optimizedPlan = testRelation.orderBy('a.asc, 'b.asc).limit(10) + val plan = testRelation.orderBy(Symbol("a").asc, Symbol("b").asc).limit(10) + val optimizedPlan = testRelation.orderBy(Symbol("a").asc, Symbol("b").asc).limit(10) checkRepartitionCases(plan, optimizedPlan) } test("orderBy with non-deterministic projection") { - val plan = testRelation.orderBy('a.asc, 'b.asc).select(rand(1), 'a, 'b) - val optimizedPlan = testRelation.orderBy('a.asc, 'b.asc).select(rand(1), 'a, 'b) + val plan = testRelation.orderBy(Symbol("a").asc, Symbol("b").asc) + .select(rand(1), Symbol("a"), Symbol("b")) + val optimizedPlan = testRelation.orderBy(Symbol("a").asc, Symbol("b").asc) + .select(rand(1), Symbol("a"), Symbol("b")) checkRepartitionCases(plan, optimizedPlan) } test("additional coalesce and sortBy") { - val plan = testRelation.sortBy('a.asc, 'b.asc).coalesce(1) + val plan = testRelation.sortBy(Symbol("a").asc, Symbol("b").asc).coalesce(1) val optimizedPlan = testRelation.coalesce(1) checkRepartitionCases(plan, optimizedPlan) } test("additional projection, repartition and sortBy") { - val plan = testRelation.sortBy('a.asc, 'b.asc).repartition(100).select('a + 1 as "a") - val optimizedPlan = testRelation.repartition(100).select('a + 1 as "a") + val plan = testRelation.sortBy(Symbol("a").asc, Symbol("b").asc).repartition(100) + .select(Symbol("a") + 1 as "a") + val optimizedPlan = testRelation.repartition(100).select(Symbol("a") + 1 as "a") checkRepartitionCases(plan, optimizedPlan) } test("additional filter, distribute and sortBy") { - val plan = testRelation.sortBy('a.asc, 'b.asc).distribute('a)(2).where('a === 10) - val optimizedPlan = testRelation.distribute('a)(2).where('a === 10) + val plan = testRelation.sortBy(Symbol("a").asc, Symbol("b").asc) + .distribute(Symbol("a"))(2).where(Symbol("a") === 10) + val optimizedPlan = testRelation.distribute(Symbol("a"))(2).where(Symbol("a") === 10) checkRepartitionCases(plan, optimizedPlan) } test("join") { - val plan = testRelation.sortBy('a.asc, 'b.asc).distribute('a)(2).where('a === 10) - val optimizedPlan = testRelation.distribute('a)(2).where('a === 10) - val anotherPlan = anotherTestRelation.select('d) + val plan = testRelation.sortBy(Symbol("a").asc, Symbol("b").asc) + .distribute(Symbol("a"))(2).where(Symbol("a") === 10) + val optimizedPlan = testRelation.distribute(Symbol("a"))(2).where(Symbol("a") === 10) + val anotherPlan = anotherTestRelation.select(Symbol("d")) val joinPlan = plan.join(anotherPlan) val optimizedJoinPlan = optimize(joinPlan) val correctJoinPlan = analyze(optimizedPlan.join(anotherPlan)) @@ -135,11 +148,12 @@ class EliminateSortsBeforeRepartitionSuite extends PlanTest { } test("aggregate") { - val plan = testRelation.sortBy('a.asc, 'b.asc).distribute('a)(2).where('a === 10) - val optimizedPlan = testRelation.distribute('a)(2).where('a === 10) - val aggPlan = plan.groupBy('a)(sum('b)) + val plan = testRelation.sortBy(Symbol("a").asc, Symbol("b").asc) + .distribute(Symbol("a"))(2).where(Symbol("a") === 10) + val optimizedPlan = testRelation.distribute(Symbol("a"))(2).where(Symbol("a") === 10) + val aggPlan = plan.groupBy(Symbol("a"))(sum(Symbol("b"))) val optimizedAggPlan = optimize(aggPlan) - val correctAggPlan = analyze(optimizedPlan.groupBy('a)(sum('b))) + val correctAggPlan = 
analyze(optimizedPlan.groupBy(Symbol("a"))(sum(Symbol("b")))) comparePlans(optimizedAggPlan, correctAggPlan) } @@ -151,15 +165,17 @@ class EliminateSortsBeforeRepartitionSuite extends PlanTest { comparePlans(optimizedPlanWithRepartition, correctPlanWithRepartition) // can remove sortBy before repartition with sortBy - val planWithRepartitionAndSortBy = planWithRepartition.sortBy('a.asc) + val planWithRepartitionAndSortBy = planWithRepartition.sortBy(Symbol("a").asc) val optimizedPlanWithRepartitionAndSortBy = optimize(planWithRepartitionAndSortBy) - val correctPlanWithRepartitionAndSortBy = analyze(repartition(optimizedPlan).sortBy('a.asc)) + val correctPlanWithRepartitionAndSortBy = + analyze(repartition(optimizedPlan).sortBy(Symbol("a").asc)) comparePlans(optimizedPlanWithRepartitionAndSortBy, correctPlanWithRepartitionAndSortBy) // can remove sortBy before repartition with orderBy - val planWithRepartitionAndOrderBy = planWithRepartition.orderBy('a.asc) + val planWithRepartitionAndOrderBy = planWithRepartition.orderBy(Symbol("a").asc) val optimizedPlanWithRepartitionAndOrderBy = optimize(planWithRepartitionAndOrderBy) - val correctPlanWithRepartitionAndOrderBy = analyze(repartition(optimizedPlan).orderBy('a.asc)) + val correctPlanWithRepartitionAndOrderBy = + analyze(repartition(optimizedPlan).orderBy(Symbol("a").asc)) comparePlans(optimizedPlanWithRepartitionAndOrderBy, correctPlanWithRepartitionAndOrderBy) } @@ -173,17 +189,17 @@ class EliminateSortsBeforeRepartitionSuite extends PlanTest { } class EliminateSortsBeforeRepartitionByExprsSuite extends EliminateSortsBeforeRepartitionSuite { - override def repartition(plan: LogicalPlan): LogicalPlan = plan.distribute('a)(10) + override def repartition(plan: LogicalPlan): LogicalPlan = plan.distribute(Symbol("a"))(10) test("sortBy before repartition with non-deterministic expressions") { - val plan = testRelation.sortBy('a.asc, 'b.asc).limit(10) - val planWithRepartition = plan.distribute(rand(1).asc, 'a.asc)(20) + val plan = testRelation.sortBy(Symbol("a").asc, Symbol("b").asc).limit(10) + val planWithRepartition = plan.distribute(rand(1).asc, Symbol("a").asc)(20) checkRepartitionCases(plan = planWithRepartition, optimizedPlan = planWithRepartition) } test("orderBy before repartition with non-deterministic expressions") { - val plan = testRelation.orderBy('a.asc, 'b.asc).limit(10) - val planWithRepartition = plan.distribute(rand(1).asc, 'a.asc)(20) + val plan = testRelation.orderBy(Symbol("a").asc, Symbol("b").asc).limit(10) + val planWithRepartition = plan.distribute(rand(1).asc, Symbol("a").asc)(20) checkRepartitionCases(plan = planWithRepartition, optimizedPlan = planWithRepartition) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index 01ecbd808c25..48840c511867 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -86,9 +86,9 @@ class EliminateSortsSuite extends AnalysisTest { val x = testRelation val analyzer = getAnalyzer - val query = x.orderBy(SortOrder(3, Ascending), 'a.asc) + val query = x.orderBy(SortOrder(3, Ascending), Symbol("a").asc) val optimized = Optimize.execute(analyzer.execute(query)) - val correctAnswer = analyzer.execute(x.orderBy('a.asc)) + val correctAnswer = analyzer.execute(x.orderBy(Symbol("a").asc)) 
comparePlans(optimized, correctAnswer) } @@ -97,11 +97,13 @@ class EliminateSortsSuite extends AnalysisTest { test("Remove no-op alias") { val x = testRelation - val query = x.select('a.as('x), Year(CurrentDate()).as('y), 'b) - .orderBy('x.asc, 'y.asc, 'b.desc) + val query = + x.select(Symbol("a").as(Symbol("x")), Year(CurrentDate()).as(Symbol("y")), Symbol("b")) + .orderBy(Symbol("x").asc, Symbol("y").asc, Symbol("b").desc) val optimized = Optimize.execute(analyzer.execute(query)) val correctAnswer = analyzer.execute( - x.select('a.as('x), Year(CurrentDate()).as('y), 'b).orderBy('x.asc, 'b.desc)) + x.select(Symbol("a").as(Symbol("x")), Year(CurrentDate()).as(Symbol("y")), Symbol("b")) + .orderBy(Symbol("x").asc, Symbol("b").desc)) comparePlans(optimized, correctAnswer) } @@ -114,72 +116,87 @@ class EliminateSortsSuite extends AnalysisTest { } test("SPARK-33183: remove redundant sort by") { - val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val unnecessaryReordered = orderedPlan.limit(2).select('a).sortBy('a.asc, 'b.desc_nullsFirst) + val orderedPlan = testRelation.select(Symbol("a"), Symbol("b")) + .orderBy(Symbol("a").asc, Symbol("b").desc_nullsFirst) + val unnecessaryReordered = orderedPlan.limit(2).select(Symbol("a")) + .sortBy(Symbol("a").asc, Symbol("b").desc_nullsFirst) val optimized = Optimize.execute(unnecessaryReordered.analyze) - val correctAnswer = orderedPlan.limit(2).select('a).analyze + val correctAnswer = orderedPlan.limit(2).select(Symbol("a")).analyze comparePlans(optimized, correctAnswer) } test("SPARK-33183: remove all redundant local sorts") { - val orderedPlan = testRelation.sortBy('a.asc).orderBy('a.asc).sortBy('a.asc) + val orderedPlan = testRelation.sortBy(Symbol("a").asc) + .orderBy(Symbol("a").asc).sortBy(Symbol("a").asc) val optimized = Optimize.execute(orderedPlan.analyze) - val correctAnswer = testRelation.orderBy('a.asc).analyze + val correctAnswer = testRelation.orderBy(Symbol("a").asc).analyze comparePlans(optimized, correctAnswer) } test("SPARK-33183: should not remove global sort") { - val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val reordered = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val orderedPlan = testRelation.select(Symbol("a"), Symbol("b")) + .orderBy(Symbol("a").asc, Symbol("b").desc_nullsFirst) + val reordered = orderedPlan.limit(2).select(Symbol("a")) + .orderBy(Symbol("a").asc, Symbol("b").desc_nullsFirst) val optimized = Optimize.execute(reordered.analyze) val correctAnswer = reordered.analyze comparePlans(optimized, correctAnswer) } test("do not remove sort if the order is different") { - val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) - val reorderedDifferently = orderedPlan.limit(2).select('a).orderBy('a.asc, 'b.desc) + val orderedPlan = testRelation.select(Symbol("a"), Symbol("b")) + .orderBy(Symbol("a").asc, Symbol("b").desc_nullsFirst) + val reorderedDifferently = orderedPlan.limit(2).select(Symbol("a")) + .orderBy(Symbol("a").asc, Symbol("b").desc) val optimized = Optimize.execute(reorderedDifferently.analyze) val correctAnswer = reorderedDifferently.analyze comparePlans(optimized, correctAnswer) } test("SPARK-33183: remove top level local sort with filter operators") { - val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) - val filteredAndReordered = orderedPlan.where('a > Literal(10)).sortBy('a.asc, 'b.desc) + val orderedPlan = testRelation.select(Symbol("a"), Symbol("b")) + 
.orderBy(Symbol("a").asc, Symbol("b").desc) + val filteredAndReordered = orderedPlan.where(Symbol("a") > Literal(10)) + .sortBy(Symbol("a").asc, Symbol("b").desc) val optimized = Optimize.execute(filteredAndReordered.analyze) - val correctAnswer = orderedPlan.where('a > Literal(10)).analyze + val correctAnswer = orderedPlan.where(Symbol("a") > Literal(10)).analyze comparePlans(optimized, correctAnswer) } test("SPARK-33183: keep top level global sort with filter operators") { - val projectPlan = testRelation.select('a, 'b) - val orderedPlan = projectPlan.orderBy('a.asc, 'b.desc) - val filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc) + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val orderedPlan = projectPlan.orderBy(Symbol("a").asc, Symbol("b").desc) + val filteredAndReordered = orderedPlan.where(Symbol("a") > Literal(10)) + .orderBy(Symbol("a").asc, Symbol("b").desc) val optimized = Optimize.execute(filteredAndReordered.analyze) - val correctAnswer = projectPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc).analyze + val correctAnswer = projectPlan.where(Symbol("a") > Literal(10)) + .orderBy(Symbol("a").asc, Symbol("b").desc).analyze comparePlans(optimized, correctAnswer) } test("SPARK-33183: limits should not affect order for local sort") { - val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) - val filteredAndReordered = orderedPlan.limit(Literal(10)).sortBy('a.asc, 'b.desc) + val orderedPlan = testRelation.select(Symbol("a"), Symbol("b")) + .orderBy(Symbol("a").asc, Symbol("b").desc) + val filteredAndReordered = orderedPlan.limit(Literal(10)) + .sortBy(Symbol("a").asc, Symbol("b").desc) val optimized = Optimize.execute(filteredAndReordered.analyze) val correctAnswer = orderedPlan.limit(Literal(10)).analyze comparePlans(optimized, correctAnswer) } test("SPARK-33183: should not remove global sort with limit operators") { - val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) - val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc) + val orderedPlan = testRelation.select(Symbol("a"), Symbol("b")) + .orderBy(Symbol("a").asc, Symbol("b").desc) + val filteredAndReordered = orderedPlan.limit(Literal(10)) + .orderBy(Symbol("a").asc, Symbol("b").desc) val optimized = Optimize.execute(filteredAndReordered.analyze) val correctAnswer = filteredAndReordered.analyze comparePlans(optimized, correctAnswer) } test("different sorts are not simplified if limit is in between") { - val orderedPlan = testRelation.select('a, 'b).orderBy('b.desc).limit(Literal(10)) - .orderBy('a.asc) + val orderedPlan = testRelation.select(Symbol("a"), Symbol("b")) + .orderBy(Symbol("b").desc).limit(Literal(10)).orderBy(Symbol("a").asc) val optimized = Optimize.execute(orderedPlan.analyze) val correctAnswer = orderedPlan.analyze comparePlans(optimized, correctAnswer) @@ -187,18 +204,18 @@ class EliminateSortsSuite extends AnalysisTest { test("SPARK-33183: should not remove global sort with range operator") { val inputPlan = Range(1L, 1000L, 1, 10) - val orderedPlan = inputPlan.orderBy('id.asc) + val orderedPlan = inputPlan.orderBy(Symbol("id").asc) val optimized = Optimize.execute(orderedPlan.analyze) val correctAnswer = orderedPlan.analyze comparePlans(optimized, correctAnswer) - val reversedPlan = inputPlan.orderBy('id.desc) + val reversedPlan = inputPlan.orderBy(Symbol("id").desc) val reversedOptimized = Optimize.execute(reversedPlan.analyze) val reversedCorrectAnswer = reversedPlan.analyze 
comparePlans(reversedOptimized, reversedCorrectAnswer) val negativeStepInputPlan = Range(10L, 1L, -1, 10) - val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc) + val negativeStepOrderedPlan = negativeStepInputPlan.orderBy(Symbol("id").desc) val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze) val negativeStepCorrectAnswer = negativeStepOrderedPlan.analyze comparePlans(negativeStepOptimized, negativeStepCorrectAnswer) @@ -206,50 +223,55 @@ class EliminateSortsSuite extends AnalysisTest { test("SPARK-33183: remove local sort with range operator") { val inputPlan = Range(1L, 1000L, 1, 10) - val orderedPlan = inputPlan.sortBy('id.asc) + val orderedPlan = inputPlan.sortBy(Symbol("id").asc) val optimized = Optimize.execute(orderedPlan.analyze) val correctAnswer = inputPlan.analyze comparePlans(optimized, correctAnswer) } test("sort should not be removed when there is a node which doesn't guarantee any order") { - val orderedPlan = testRelation.select('a, 'b) - val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc) + val orderedPlan = testRelation.select(Symbol("a"), Symbol("b")) + val groupedAndResorted = + orderedPlan.groupBy(Symbol("a"))(sum(Symbol("a"))).orderBy(Symbol("a").asc) val optimized = Optimize.execute(groupedAndResorted.analyze) val correctAnswer = groupedAndResorted.analyze comparePlans(optimized, correctAnswer) } test("remove two consecutive sorts") { - val orderedTwice = testRelation.orderBy('a.asc).orderBy('b.desc) + val orderedTwice = testRelation.orderBy(Symbol("a").asc).orderBy(Symbol("b").desc) val optimized = Optimize.execute(orderedTwice.analyze) - val correctAnswer = testRelation.orderBy('b.desc).analyze + val correctAnswer = testRelation.orderBy(Symbol("b").desc).analyze comparePlans(optimized, correctAnswer) } test("remove sorts separated by Filter/Project operators") { - val orderedTwiceWithProject = testRelation.orderBy('a.asc).select('b).orderBy('b.desc) + val orderedTwiceWithProject = testRelation.orderBy(Symbol("a").asc) + .select(Symbol("b")).orderBy(Symbol("b").desc) val optimizedWithProject = Optimize.execute(orderedTwiceWithProject.analyze) - val correctAnswerWithProject = testRelation.select('b).orderBy('b.desc).analyze + val correctAnswerWithProject = + testRelation.select(Symbol("b")).orderBy(Symbol("b").desc).analyze comparePlans(optimizedWithProject, correctAnswerWithProject) - val orderedTwiceWithFilter = - testRelation.orderBy('a.asc).where('b > Literal(0)).orderBy('b.desc) + val orderedTwiceWithFilter = testRelation.orderBy(Symbol("a").asc) + .where(Symbol("b") > Literal(0)).orderBy(Symbol("b").desc) val optimizedWithFilter = Optimize.execute(orderedTwiceWithFilter.analyze) - val correctAnswerWithFilter = testRelation.where('b > Literal(0)).orderBy('b.desc).analyze + val correctAnswerWithFilter = + testRelation.where(Symbol("b") > Literal(0)).orderBy(Symbol("b").desc).analyze comparePlans(optimizedWithFilter, correctAnswerWithFilter) - val orderedTwiceWithBoth = - testRelation.orderBy('a.asc).select('b).where('b > Literal(0)).orderBy('b.desc) + val orderedTwiceWithBoth = testRelation.orderBy(Symbol("a").asc).select(Symbol("b")) + .where(Symbol("b") > Literal(0)).orderBy(Symbol("b").desc) val optimizedWithBoth = Optimize.execute(orderedTwiceWithBoth.analyze) - val correctAnswerWithBoth = - testRelation.select('b).where('b > Literal(0)).orderBy('b.desc).analyze + val correctAnswerWithBoth = testRelation.select(Symbol("b")).where(Symbol("b") > Literal(0)) + .orderBy(Symbol("b").desc).analyze 
comparePlans(optimizedWithBoth, correctAnswerWithBoth) - val orderedThrice = orderedTwiceWithBoth.select(('b + 1).as('c)).orderBy('c.asc) + val orderedThrice = + orderedTwiceWithBoth.select((Symbol("b") + 1).as(Symbol("c"))).orderBy(Symbol("c").asc) val optimizedThrice = Optimize.execute(orderedThrice.analyze) - val correctAnswerThrice = testRelation.select('b).where('b > Literal(0)) - .select(('b + 1).as('c)).orderBy('c.asc).analyze + val correctAnswerThrice = testRelation.select(Symbol("b")).where(Symbol("b") > Literal(0)) + .select((Symbol("b") + 1).as(Symbol("c"))).orderBy(Symbol("c").asc).analyze comparePlans(optimizedThrice, correctAnswerThrice) } @@ -265,37 +287,37 @@ class EliminateSortsSuite extends AnalysisTest { (e : Expression) => bitOr(e), (e : Expression) => bitXor(e) ).foreach(agg => { - val projectPlan = testRelation.select('a, 'b) - val unnecessaryOrderByPlan = projectPlan.orderBy('a.asc, 'b.desc) - val groupByPlan = unnecessaryOrderByPlan.groupBy('a)(agg('b)) + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val unnecessaryOrderByPlan = projectPlan.orderBy(Symbol("a").asc, Symbol("b").desc) + val groupByPlan = unnecessaryOrderByPlan.groupBy(Symbol("a"))(agg(Symbol("b"))) val optimized = Optimize.execute(groupByPlan.analyze) - val correctAnswer = projectPlan.groupBy('a)(agg('b)).analyze + val correctAnswer = projectPlan.groupBy(Symbol("a"))(agg(Symbol("b"))).analyze comparePlans(optimized, correctAnswer) }) } test("remove orderBy in groupBy clause with sum aggs") { - val projectPlan = testRelation.select('a, 'b) - val unnecessaryOrderByPlan = projectPlan.orderBy('a.asc, 'b.desc) - val groupByPlan = unnecessaryOrderByPlan.groupBy('a)(sum('a) + 10 as "sum") + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val unnecessaryOrderByPlan = projectPlan.orderBy(Symbol("a").asc, Symbol("b").desc) + val groupByPlan = unnecessaryOrderByPlan.groupBy(Symbol("a"))(sum(Symbol("a")) + 10 as "sum") val optimized = Optimize.execute(groupByPlan.analyze) - val correctAnswer = projectPlan.groupBy('a)(sum('a) + 10 as "sum").analyze + val correctAnswer = projectPlan.groupBy(Symbol("a"))(sum(Symbol("a")) + 10 as "sum").analyze comparePlans(optimized, correctAnswer) } test("should not remove orderBy in groupBy clause with first aggs") { - val projectPlan = testRelation.select('a, 'b) - val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc) - val groupByPlan = orderByPlan.groupBy('a)(first('a)) + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val orderByPlan = projectPlan.orderBy(Symbol("a").asc, Symbol("b").desc) + val groupByPlan = orderByPlan.groupBy(Symbol("a"))(first(Symbol("a"))) val optimized = Optimize.execute(groupByPlan.analyze) val correctAnswer = groupByPlan.analyze comparePlans(optimized, correctAnswer) } test("should not remove orderBy in groupBy clause with first and count aggs") { - val projectPlan = testRelation.select('a, 'b) - val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc) - val groupByPlan = orderByPlan.groupBy('a)(first('a), count(1)) + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val orderByPlan = projectPlan.orderBy(Symbol("a").asc, Symbol("b").desc) + val groupByPlan = orderByPlan.groupBy(Symbol("a"))(first(Symbol("a")), count(1)) val optimized = Optimize.execute(groupByPlan.analyze) val correctAnswer = groupByPlan.analyze comparePlans(optimized, correctAnswer) @@ -304,67 +326,67 @@ class EliminateSortsSuite extends AnalysisTest { test("should not remove orderBy in groupBy clause with 
PythonUDF as aggs") { val pythonUdf = PythonUDF("pyUDF", null, IntegerType, Seq.empty, PythonEvalType.SQL_BATCHED_UDF, true) - val projectPlan = testRelation.select('a, 'b) - val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc) - val groupByPlan = orderByPlan.groupBy('a)(pythonUdf) + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val orderByPlan = projectPlan.orderBy(Symbol("a").asc, Symbol("b").desc) + val groupByPlan = orderByPlan.groupBy(Symbol("a"))(pythonUdf) val optimized = Optimize.execute(groupByPlan.analyze) val correctAnswer = groupByPlan.analyze comparePlans(optimized, correctAnswer) } test("should not remove orderBy in groupBy clause with ScalaUDF as aggs") { - val scalaUdf = ScalaUDF((s: Int) => s, IntegerType, 'a :: Nil, + val scalaUdf = ScalaUDF((s: Int) => s, IntegerType, Symbol("a") :: Nil, Option(ExpressionEncoder[Int]()) :: Nil) - val projectPlan = testRelation.select('a, 'b) - val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc) - val groupByPlan = orderByPlan.groupBy('a)(scalaUdf) + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val orderByPlan = projectPlan.orderBy(Symbol("a").asc, Symbol("b").desc) + val groupByPlan = orderByPlan.groupBy(Symbol("a"))(scalaUdf) val optimized = Optimize.execute(groupByPlan.analyze) val correctAnswer = groupByPlan.analyze comparePlans(optimized, correctAnswer) } test("should not remove orderBy with limit in groupBy clause") { - val projectPlan = testRelation.select('a, 'b) - val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc).limit(10) - val groupByPlan = orderByPlan.groupBy('a)(count(1)) + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val orderByPlan = projectPlan.orderBy(Symbol("a").asc, Symbol("b").desc).limit(10) + val groupByPlan = orderByPlan.groupBy(Symbol("a"))(count(1)) val optimized = Optimize.execute(groupByPlan.analyze) val correctAnswer = groupByPlan.analyze comparePlans(optimized, correctAnswer) } test("remove orderBy in join clause") { - val projectPlan = testRelation.select('a, 'b) - val unnecessaryOrderByPlan = projectPlan.orderBy('a.asc, 'b.desc) - val projectPlanB = testRelationB.select('d) - val joinPlan = unnecessaryOrderByPlan.join(projectPlanB).select('a, 'd) + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val unnecessaryOrderByPlan = projectPlan.orderBy(Symbol("a").asc, Symbol("b").desc) + val projectPlanB = testRelationB.select(Symbol("d")) + val joinPlan = unnecessaryOrderByPlan.join(projectPlanB).select(Symbol("a"), Symbol("d")) val optimized = Optimize.execute(joinPlan.analyze) - val correctAnswer = projectPlan.join(projectPlanB).select('a, 'd).analyze + val correctAnswer = projectPlan.join(projectPlanB).select(Symbol("a"), Symbol("d")).analyze comparePlans(optimized, correctAnswer) } test("should not remove orderBy with limit in join clause") { - val projectPlan = testRelation.select('a, 'b) - val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc).limit(10) - val projectPlanB = testRelationB.select('d) - val joinPlan = orderByPlan.join(projectPlanB).select('a, 'd) + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val orderByPlan = projectPlan.orderBy(Symbol("a").asc, Symbol("b").desc).limit(10) + val projectPlanB = testRelationB.select(Symbol("d")) + val joinPlan = orderByPlan.join(projectPlanB).select(Symbol("a"), Symbol("d")) val optimized = Optimize.execute(joinPlan.analyze) val correctAnswer = joinPlan.analyze comparePlans(optimized, correctAnswer) } test("SPARK-32318: should not remove orderBy in 
distribute statement") { - val projectPlan = testRelation.select('a, 'b) - val orderByPlan = projectPlan.orderBy('b.desc) - val distributedPlan = orderByPlan.distribute('a)(1) + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val orderByPlan = projectPlan.orderBy(Symbol("b").desc) + val distributedPlan = orderByPlan.distribute(Symbol("a"))(1) val optimized = Optimize.execute(distributedPlan.analyze) val correctAnswer = distributedPlan.analyze comparePlans(optimized, correctAnswer) } test("should not remove orderBy in left join clause if there is an outer limit") { - val projectPlan = testRelation.select('a, 'b) - val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc) - val projectPlanB = testRelationB.select('d) + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val orderByPlan = projectPlan.orderBy(Symbol("a").asc, Symbol("b").desc) + val projectPlanB = testRelationB.select(Symbol("d")) val joinPlan = orderByPlan .join(projectPlanB, LeftOuter) .limit(10) @@ -374,9 +396,9 @@ class EliminateSortsSuite extends AnalysisTest { } test("remove orderBy in right join clause event if there is an outer limit") { - val projectPlan = testRelation.select('a, 'b) - val orderByPlan = projectPlan.orderBy('a.asc, 'b.desc) - val projectPlanB = testRelationB.select('d) + val projectPlan = testRelation.select(Symbol("a"), Symbol("b")) + val orderByPlan = projectPlan.orderBy(Symbol("a").asc, Symbol("b").desc) + val projectPlanB = testRelationB.select(Symbol("d")) val joinPlan = orderByPlan .join(projectPlanB, RightOuter) .limit(10) @@ -390,8 +412,10 @@ class EliminateSortsSuite extends AnalysisTest { test("SPARK-33183: remove consecutive global sorts with the same ordering") { Seq( - (testRelation.orderBy('a.asc).orderBy('a.asc), testRelation.orderBy('a.asc)), - (testRelation.orderBy('a.asc, 'b.desc).orderBy('a.asc), testRelation.orderBy('a.asc)) + (testRelation.orderBy(Symbol("a").asc).orderBy(Symbol("a").asc), + testRelation.orderBy(Symbol("a").asc)), + (testRelation.orderBy(Symbol("a").asc, Symbol("b").desc).orderBy(Symbol("a").asc), + testRelation.orderBy(Symbol("a").asc)) ).foreach { case (ordered, answer) => val optimized = Optimize.execute(ordered.analyze) comparePlans(optimized, answer.analyze) @@ -399,24 +423,26 @@ class EliminateSortsSuite extends AnalysisTest { } test("SPARK-33183: remove consecutive local sorts with the same ordering") { - val orderedPlan = testRelation.sortBy('a.asc).sortBy('a.asc).sortBy('a.asc) + val orderedPlan = + testRelation.sortBy(Symbol("a").asc).sortBy(Symbol("a").asc).sortBy(Symbol("a").asc) val optimized = Optimize.execute(orderedPlan.analyze) - val correctAnswer = testRelation.sortBy('a.asc).analyze + val correctAnswer = testRelation.sortBy(Symbol("a").asc).analyze comparePlans(optimized, correctAnswer) } test("SPARK-33183: remove consecutive local sorts with different ordering") { - val orderedPlan = testRelation.sortBy('b.asc).sortBy('a.desc).sortBy('a.asc) + val orderedPlan = + testRelation.sortBy(Symbol("b").asc).sortBy(Symbol("a").desc).sortBy(Symbol("a").asc) val optimized = Optimize.execute(orderedPlan.analyze) - val correctAnswer = testRelation.sortBy('a.asc).analyze + val correctAnswer = testRelation.sortBy(Symbol("a").asc).analyze comparePlans(optimized, correctAnswer) } test("SPARK-33183: should keep global sort when child is a local sort with the same ordering") { - val correctAnswer = testRelation.orderBy('a.asc).analyze + val correctAnswer = testRelation.orderBy(Symbol("a").asc).analyze Seq( - 
testRelation.sortBy('a.asc).orderBy('a.asc), - testRelation.orderBy('a.asc).sortBy('a.asc).orderBy('a.asc) + testRelation.sortBy(Symbol("a").asc).orderBy(Symbol("a").asc), + testRelation.orderBy(Symbol("a").asc).sortBy(Symbol("a").asc).orderBy(Symbol("a").asc) ).foreach { ordered => val optimized = Optimize.execute(ordered.analyze) comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala index 4df1a145a271..1353e5bb929c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala @@ -45,25 +45,25 @@ class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper { } test("eliminate top level subquery") { - val input = LocalRelation('a.int, 'b.int) + val input = LocalRelation(Symbol("a").int, Symbol("b").int) val query = SubqueryAlias("a", input) comparePlans(afterOptimization(query), input) } test("eliminate mid-tree subquery") { - val input = LocalRelation('a.int, 'b.int) + val input = LocalRelation(Symbol("a").int, Symbol("b").int) val query = Filter(TrueLiteral, SubqueryAlias("a", input)) comparePlans( afterOptimization(query), - Filter(TrueLiteral, LocalRelation('a.int, 'b.int))) + Filter(TrueLiteral, LocalRelation(Symbol("a").int, Symbol("b").int))) } test("eliminate multiple subqueries") { - val input = LocalRelation('a.int, 'b.int) + val input = LocalRelation(Symbol("a").int, Symbol("b").int) val query = Filter(TrueLiteral, SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input)))) comparePlans( afterOptimization(query), - Filter(TrueLiteral, LocalRelation('a.int, 'b.int))) + Filter(TrueLiteral, LocalRelation(Symbol("a").int, Symbol("b").int))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala index 77bfc0b3682a..516629159a51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ExtractPythonUDFFromJoinConditionSuite.scala @@ -38,10 +38,10 @@ class ExtractPythonUDFFromJoinConditionSuite extends PlanTest { CheckCartesianProducts) :: Nil } - val attrA = 'a.int - val attrB = 'b.int - val attrC = 'c.int - val attrD = 'd.int + val attrA = Symbol("a").int + val attrB = Symbol("b").int + val attrC = Symbol("c").int + val attrD = Symbol("d").int val testRelationLeft = LocalRelation(attrA, attrB) val testRelationRight = LocalRelation(attrC, attrD) @@ -105,11 +105,11 @@ class ExtractPythonUDFFromJoinConditionSuite extends PlanTest { val query = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some(unevaluableJoinCond && 'a.attr === 'c.attr)) + condition = Some(unevaluableJoinCond && Symbol("a").attr === Symbol("c").attr)) val expected = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond).analyze + condition = Some(Symbol("a").attr === Symbol("c").attr)).where(unevaluableJoinCond).analyze val optimized = Optimize.execute(query.analyze) comparePlans(optimized, expected) } @@ 
-118,11 +118,11 @@ class ExtractPythonUDFFromJoinConditionSuite extends PlanTest { val query = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some(unevaluableJoinCond || 'a.attr === 'c.attr)) + condition = Some(unevaluableJoinCond || Symbol("a").attr === Symbol("c").attr)) val expected = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = None).where(unevaluableJoinCond || 'a.attr === 'c.attr).analyze + condition = None).where(unevaluableJoinCond || Symbol("a").attr === Symbol("c").attr).analyze comparePlanWithCrossJoinEnable(query, expected) } @@ -132,7 +132,7 @@ class ExtractPythonUDFFromJoinConditionSuite extends PlanTest { Seq(attrA, attrC), PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) - val condition = (unevaluableJoinCond || 'a.attr === 'c.attr) && pythonUDF1 + val condition = (unevaluableJoinCond || Symbol("a").attr === Symbol("c").attr) && pythonUDF1 val query = testRelationLeft.join( testRelationRight, @@ -151,16 +151,18 @@ class ExtractPythonUDFFromJoinConditionSuite extends PlanTest { Seq(attrA, attrC), PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) - val condition = (unevaluableJoinCond || pythonUDF1) && 'a.attr === 'c.attr + val condition = (unevaluableJoinCond || pythonUDF1) && Symbol("a").attr === Symbol("c").attr val query = testRelationLeft.join( testRelationRight, joinType = Inner, condition = Some(condition)) - val expected = testRelationLeft.join( - testRelationRight, - joinType = Inner, - condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond || pythonUDF1).analyze + val expected = testRelationLeft + .join( + testRelationRight, + joinType = Inner, + condition = Some(Symbol("a").attr === Symbol("c").attr)) + .where(unevaluableJoinCond || pythonUDF1).analyze val optimized = Optimize.execute(query.analyze) comparePlans(optimized, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownOnePassSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownOnePassSuite.scala index 6f1280c90e9d..a010ba5e7822 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownOnePassSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownOnePassSuite.scala @@ -42,12 +42,12 @@ class FilterPushdownOnePassSuite extends PlanTest { ) :: Nil } - val testRelation1 = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation2 = LocalRelation('a.int, 'd.int, 'e.int) + val testRelation1 = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val testRelation2 = LocalRelation(Symbol("a").int, Symbol("d").int, Symbol("e").int) test("really simple predicate push down") { - val x = testRelation1.subquery('x) - val y = testRelation2.subquery('y) + val x = testRelation1.subquery(Symbol("x")) + val y = testRelation2.subquery(Symbol("y")) val originalQuery = x.join(y).where("x.a".attr === 1) @@ -58,8 +58,8 @@ class FilterPushdownOnePassSuite extends PlanTest { } test("push down conjunctive predicates") { - val x = testRelation1.subquery('x) - val y = testRelation2.subquery('y) + val x = testRelation1.subquery(Symbol("x")) + val y = testRelation2.subquery(Symbol("y")) val originalQuery = x.join(y).where("x.a".attr === 1 && "y.d".attr < 1) @@ -70,8 +70,8 @@ class FilterPushdownOnePassSuite extends PlanTest { } test("push down predicates for simple joins") { - val x = testRelation1.subquery('x) - val y = testRelation2.subquery('y) + val x = 
testRelation1.subquery(Symbol("x")) + val y = testRelation2.subquery(Symbol("y")) val originalQuery = x.where("x.c".attr < 0) @@ -87,8 +87,8 @@ class FilterPushdownOnePassSuite extends PlanTest { } test("push down top-level filters for cascading joins") { - val x = testRelation1.subquery('x) - val y = testRelation2.subquery('y) + val x = testRelation1.subquery(Symbol("x")) + val y = testRelation2.subquery(Symbol("y")) val originalQuery = y.join(x).join(x).join(x).join(x).join(x).where("y.d".attr === 0) @@ -100,9 +100,9 @@ class FilterPushdownOnePassSuite extends PlanTest { } test("push down predicates for tree-like joins") { - val x = testRelation1.subquery('x) - val y1 = testRelation2.subquery('y1) - val y2 = testRelation2.subquery('y2) + val x = testRelation1.subquery(Symbol("x")) + val y1 = testRelation2.subquery(Symbol("y1")) + val y2 = testRelation2.subquery(Symbol("y2")) val originalQuery = y1.join(x).join(x) @@ -118,64 +118,66 @@ class FilterPushdownOnePassSuite extends PlanTest { } test("push down through join and project") { - val x = testRelation1.subquery('x) - val y = testRelation2.subquery('y) + val x = testRelation1.subquery(Symbol("x")) + val y = testRelation2.subquery(Symbol("y")) val originalQuery = - x.where('a > 0).select('a, 'b) - .join(y.where('d < 100).select('e)) + x.where(Symbol("a") > 0).select(Symbol("a"), Symbol("b")) + .join(y.where(Symbol("d") < 100).select(Symbol("e"))) .where("x.a".attr < 100) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = - x.where('a > 0 && 'a < 100).select('a, 'b) - .join(y.where('d < 100).select('e)).analyze + x.where(Symbol("a") > 0 && Symbol("a") < 100).select(Symbol("a"), Symbol("b")) + .join(y.where(Symbol("d") < 100).select(Symbol("e"))).analyze comparePlans(optimized, correctAnswer) } test("push down through deep projects") { - val x = testRelation1.subquery('x) + val x = testRelation1.subquery(Symbol("x")) val originalQuery = - x.select(('a + 1) as 'a1, 'b) - .select(('a1 + 1) as 'a2, 'b) - .select(('a2 + 1) as 'a3, 'b) - .select(('a3 + 1) as 'a4, 'b) - .select('b) - .where('b > 0) + x.select((Symbol("a") + 1) as Symbol("a1"), Symbol("b")) + .select((Symbol("a1") + 1) as Symbol("a2"), Symbol("b")) + .select((Symbol("a2") + 1) as Symbol("a3"), Symbol("b")) + .select((Symbol("a3") + 1) as Symbol("a4"), Symbol("b")) + .select(Symbol("b")) + .where(Symbol("b") > 0) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = - x.where('b > 0) - .select(('a + 1) as 'a1, 'b) - .select(('a1 + 1) as 'a2, 'b) - .select(('a2 + 1) as 'a3, 'b) - .select(('a3 + 1) as 'a4, 'b) - .select('b).analyze + x.where(Symbol("b") > 0) + .select((Symbol("a") + 1) as Symbol("a1"), Symbol("b")) + .select((Symbol("a1") + 1) as Symbol("a2"), Symbol("b")) + .select((Symbol("a2") + 1) as Symbol("a3"), Symbol("b")) + .select((Symbol("a3") + 1) as Symbol("a4"), Symbol("b")) + .select(Symbol("b")).analyze comparePlans(optimized, correctAnswer) } test("push down through aggregate and join") { - val x = testRelation1.subquery('x) - val y = testRelation2.subquery('y) + val x = testRelation1.subquery(Symbol("x")) + val y = testRelation2.subquery(Symbol("y")) val left = x - .where('c > 0) - .groupBy('a)('a, count('b)) - .subquery('left) + .where(Symbol("c") > 0) + .groupBy(Symbol("a"))(Symbol("a"), count(Symbol("b"))) + .subquery(Symbol("left")) val right = y - .where('d < 0) - .groupBy('a)('a, count('d)) - .subquery('right) + .where(Symbol("d") < 0) + .groupBy(Symbol("a"))(Symbol("a"), count(Symbol("d"))) + 
.subquery(Symbol("right")) val originalQuery = left .join(right).where("left.a".attr < 100 && "right.a".attr < 100) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = - x.where('c > 0 && 'a < 100).groupBy('a)('a, count('b)) - .join(y.where('d < 0 && 'a < 100).groupBy('a)('a, count('d))) + x.where(Symbol("c") > 0 && Symbol("a") < 100) + .groupBy(Symbol("a"))(Symbol("a"), count(Symbol("b"))) + .join(y.where(Symbol("d") < 0 && Symbol("a") < 100) + .groupBy(Symbol("a"))(Symbol("a"), count(Symbol("d")))) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index c518fdded211..cc1c7643c361 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -46,10 +46,10 @@ class FilterPushdownSuite extends PlanTest { PushDownPredicates) :: Nil } - val attrA = 'a.int - val attrB = 'b.int - val attrC = 'c.int - val attrD = 'd.int + val attrA = Symbol("a").int + val attrB = Symbol("b").int + val attrC = Symbol("c").int + val attrD = Symbol("d").int val testRelation = LocalRelation(attrA, attrB, attrC) @@ -58,8 +58,8 @@ class FilterPushdownSuite extends PlanTest { val simpleDisjunctivePredicate = ("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11) val expectedPredicatePushDownResult = { - val left = testRelation.where(('a > 3 || 'a > 1)).subquery('x) - val right = testRelation.where('a > 13 || 'a > 11).subquery('y) + val left = testRelation.where((Symbol("a") > 3 || Symbol("a") > 1)).subquery(Symbol("x")) + val right = testRelation.where(Symbol("a") > 13 || Symbol("a") > 11).subquery(Symbol("y")) left.join(right, condition = Some("x.b".attr === "y.b".attr && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))).analyze } @@ -68,13 +68,13 @@ class FilterPushdownSuite extends PlanTest { test("eliminate subqueries") { val originalQuery = testRelation - .subquery('y) - .select('a) + .subquery(Symbol("y")) + .select(Symbol("a")) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a.attr) + .select(Symbol("a").attr) .analyze comparePlans(optimized, correctAnswer) @@ -84,14 +84,14 @@ class FilterPushdownSuite extends PlanTest { test("simple push down") { val originalQuery = testRelation - .select('a) - .where('a === 1) + .select(Symbol("a")) + .where(Symbol("a") === 1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where('a === 1) - .select('a) + .where(Symbol("a") === 1) + .select(Symbol("a")) .analyze comparePlans(optimized, correctAnswer) @@ -100,13 +100,13 @@ class FilterPushdownSuite extends PlanTest { test("combine redundant filters") { val originalQuery = testRelation - .where('a === 1 && 'b === 1) - .where('a === 1 && 'c === 1) + .where(Symbol("a") === 1 && Symbol("b") === 1) + .where(Symbol("a") === 1 && Symbol("c") === 1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where('a === 1 && 'b === 1 && 'c === 1) + .where(Symbol("a") === 1 && Symbol("b") === 1 && Symbol("c") === 1) .analyze comparePlans(optimized, correctAnswer) @@ -115,8 +115,8 @@ class FilterPushdownSuite extends PlanTest { test("do not combine non-deterministic filters even if they 
are identical") { val originalQuery = testRelation - .where(Rand(0) > 0.1 && 'a === 1) - .where(Rand(0) > 0.1 && 'a === 1).analyze + .where(Rand(0) > 0.1 && Symbol("a") === 1) + .where(Rand(0) > 0.1 && Symbol("a") === 1).analyze val optimized = Optimize.execute(originalQuery) @@ -126,15 +126,15 @@ class FilterPushdownSuite extends PlanTest { test("SPARK-16164: Filter pushdown should keep the ordering in the logical plan") { val originalQuery = testRelation - .where('a === 1) - .select('a, 'b) - .where('b === 1) + .where(Symbol("a") === 1) + .select(Symbol("a"), Symbol("b")) + .where(Symbol("b") === 1) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where('a === 1 && 'b === 1) - .select('a, 'b) + .where(Symbol("a") === 1 && Symbol("b") === 1) + .select(Symbol("a"), Symbol("b")) .analyze // We can not use comparePlans here because it normalized the plan. @@ -142,7 +142,7 @@ class FilterPushdownSuite extends PlanTest { } test("SPARK-16994: filter should not be pushed through limit") { - val originalQuery = testRelation.limit(10).where('a === 1).analyze + val originalQuery = testRelation.limit(10).where(Symbol("a") === 1).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, originalQuery) } @@ -150,15 +150,15 @@ class FilterPushdownSuite extends PlanTest { test("can't push without rewrite") { val originalQuery = testRelation - .select('a + 'b as 'e) - .where('e === 1) + .select(Symbol("a") + Symbol("b") as Symbol("e")) + .where(Symbol("e") === 1) .analyze val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where('a + 'b === 1) - .select('a + 'b as 'e) + .where(Symbol("a") + Symbol("b") === 1) + .select(Symbol("a") + Symbol("b") as Symbol("e")) .analyze comparePlans(optimized, correctAnswer) @@ -166,15 +166,15 @@ class FilterPushdownSuite extends PlanTest { test("nondeterministic: can always push down filter through project with deterministic field") { val originalQuery = testRelation - .select('a) - .where(Rand(10) > 5 || 'a > 5) + .select(Symbol("a")) + .where(Rand(10) > 5 || Symbol("a") > 5) .analyze val optimized = Optimize.execute(originalQuery) val correctAnswer = testRelation - .where(Rand(10) > 5 || 'a > 5) - .select('a) + .where(Rand(10) > 5 || Symbol("a") > 5) + .select(Symbol("a")) .analyze comparePlans(optimized, correctAnswer) @@ -182,8 +182,8 @@ class FilterPushdownSuite extends PlanTest { test("nondeterministic: can't push down filter through project with nondeterministic field") { val originalQuery = testRelation - .select(Rand(10).as('rand), 'a) - .where('a > 5) + .select(Rand(10).as(Symbol("rand")), Symbol("a")) + .where(Symbol("a") > 5) .analyze val optimized = Optimize.execute(originalQuery) @@ -193,8 +193,8 @@ class FilterPushdownSuite extends PlanTest { test("nondeterministic: can't push down filter through aggregate with nondeterministic field") { val originalQuery = testRelation - .groupBy('a)('a, Rand(10).as('rand)) - .where('a > 5) + .groupBy(Symbol("a"))(Symbol("a"), Rand(10).as(Symbol("rand"))) + .where(Symbol("a") > 5) .analyze val optimized = Optimize.execute(originalQuery) @@ -204,15 +204,15 @@ class FilterPushdownSuite extends PlanTest { test("nondeterministic: push down part of filter through aggregate with deterministic field") { val originalQuery = testRelation - .groupBy('a)('a) - .where('a > 5 && Rand(10) > 5) + .groupBy(Symbol("a"))(Symbol("a")) + .where(Symbol("a") > 5 && Rand(10) > 5) .analyze val optimized = 
Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where('a > 5) - .groupBy('a)('a) + .where(Symbol("a") > 5) + .groupBy(Symbol("a"))(Symbol("a")) .where(Rand(10) > 5) .analyze @@ -221,22 +221,22 @@ class FilterPushdownSuite extends PlanTest { test("filters: combines filters") { val originalQuery = testRelation - .select('a) - .where('a === 1) - .where('a === 2) + .select(Symbol("a")) + .where(Symbol("a") === 1) + .where(Symbol("a") === 2) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where('a === 1 && 'a === 2) - .select('a).analyze + .where(Symbol("a") === 1 && Symbol("a") === 2) + .select(Symbol("a")).analyze comparePlans(optimized, correctAnswer) } test("joins: push to either side") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y) @@ -245,8 +245,8 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('b === 1) - val right = testRelation.where('b === 2) + val left = testRelation.where(Symbol("b") === 1) + val right = testRelation.where(Symbol("b") === 2) val correctAnswer = left.join(right).analyze @@ -254,8 +254,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: push to one side") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y) @@ -263,7 +263,7 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('b === 1) + val left = testRelation.where(Symbol("b") === 1) val right = testRelation val correctAnswer = left.join(right).analyze @@ -272,8 +272,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: do not push down non-deterministic filters into join condition") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) val originalQuery = x.join(y).where(Rand(10) > 5.0).analyze val optimized = Optimize.execute(originalQuery) @@ -282,8 +282,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: push to one side after transformCondition") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) val originalQuery = { x.join(y) @@ -292,7 +292,7 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('a === 1) + val left = testRelation.where(Symbol("a") === 1) val right = testRelation1 val correctAnswer = left.join(right, condition = Some("d".attr === "b".attr || "d".attr === "c".attr)).analyze @@ -301,8 +301,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: rewrite filter to push to either side") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y) @@ -310,8 +310,8 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('b === 1) - val right = testRelation.where('b === 2) + val left = 
testRelation.where(Symbol("b") === 1) + val right = testRelation.where(Symbol("b") === 2) val correctAnswer = left.join(right).analyze @@ -319,16 +319,16 @@ class FilterPushdownSuite extends PlanTest { } test("joins: push down left semi join") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) val originalQuery = { x.join(y, LeftSemi, Option("x.a".attr === "y.d".attr && "x.b".attr >= 1 && "y.d".attr >= 2)) } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('b >= 1) - val right = testRelation1.where('d >= 2) + val left = testRelation.where(Symbol("b") >= 1) + val right = testRelation1.where(Symbol("d") >= 2) val correctAnswer = left.join(right, LeftSemi, Option("a".attr === "d".attr)).analyze @@ -336,8 +336,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: push down left outer join #1") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y, LeftOuter) @@ -345,7 +345,7 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('b === 1) + val left = testRelation.where(Symbol("b") === 1) val correctAnswer = left.join(y, LeftOuter).where("y.b".attr === 2).analyze @@ -353,8 +353,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: push down right outer join #1") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y, RightOuter) @@ -362,7 +362,7 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val right = testRelation.where('b === 2).subquery('d) + val right = testRelation.where(Symbol("b") === 2).subquery(Symbol("d")) val correctAnswer = x.join(right, RightOuter).where("x.b".attr === 1).analyze @@ -370,8 +370,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: push down left outer join #2") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y, LeftOuter, Some("x.b".attr === 1)) @@ -379,7 +379,7 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('b === 2).subquery('d) + val left = testRelation.where(Symbol("b") === 2).subquery(Symbol("d")) val correctAnswer = left.join(y, LeftOuter, Some("d.b".attr === 1)).where("y.b".attr === 2).analyze @@ -387,8 +387,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: push down right outer join #2") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y, RightOuter, Some("y.b".attr === 1)) @@ -396,7 +396,7 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val right = testRelation.where('b === 2).subquery('d) + val right = testRelation.where(Symbol("b") === 2).subquery(Symbol("d")) val correctAnswer = x.join(right, RightOuter, Some("d.b".attr === 1)).where("x.b".attr === 2).analyze @@ -404,8 +404,8 @@ class FilterPushdownSuite 
extends PlanTest { } test("joins: push down left outer join #3") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y, LeftOuter, Some("y.b".attr === 1)) @@ -413,8 +413,8 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('b === 2).subquery('l) - val right = testRelation.where('b === 1).subquery('r) + val left = testRelation.where(Symbol("b") === 2).subquery(Symbol("l")) + val right = testRelation.where(Symbol("b") === 1).subquery(Symbol("r")) val correctAnswer = left.join(right, LeftOuter).where("r.b".attr === 2).analyze @@ -422,8 +422,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: push down right outer join #3") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y, RightOuter, Some("y.b".attr === 1)) @@ -431,7 +431,7 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val right = testRelation.where('b === 2).subquery('r) + val right = testRelation.where(Symbol("b") === 2).subquery(Symbol("r")) val correctAnswer = x.join(right, RightOuter, Some("r.b".attr === 1)).where("x.b".attr === 2).analyze @@ -439,8 +439,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: push down left outer join #4") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y, LeftOuter, Some("y.b".attr === 1)) @@ -448,8 +448,8 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('b === 2).subquery('l) - val right = testRelation.where('b === 1).subquery('r) + val left = testRelation.where(Symbol("b") === 2).subquery(Symbol("l")) + val right = testRelation.where(Symbol("b") === 1).subquery(Symbol("r")) val correctAnswer = left.join(right, LeftOuter).where("r.b".attr === 2 && "l.c".attr === "r.c".attr).analyze @@ -457,8 +457,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: push down right outer join #4") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y, RightOuter, Some("y.b".attr === 1)) @@ -466,8 +466,8 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.subquery('l) - val right = testRelation.where('b === 2).subquery('r) + val left = testRelation.subquery(Symbol("l")) + val right = testRelation.where(Symbol("b") === 2).subquery(Symbol("r")) val correctAnswer = left.join(right, RightOuter, Some("r.b".attr === 1)). 
where("l.b".attr === 2 && "l.c".attr === "r.c".attr).analyze @@ -476,8 +476,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: push down left outer join #5") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y, LeftOuter, Some("y.b".attr === 1 && "x.a".attr === 3)) @@ -485,8 +485,8 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('b === 2).subquery('l) - val right = testRelation.where('b === 1).subquery('r) + val left = testRelation.where(Symbol("b") === 2).subquery(Symbol("l")) + val right = testRelation.where(Symbol("b") === 1).subquery(Symbol("r")) val correctAnswer = left.join(right, LeftOuter, Some("l.a".attr===3)). where("r.b".attr === 2 && "l.c".attr === "r.c".attr).analyze @@ -495,8 +495,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: push down right outer join #5") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y, RightOuter, Some("y.b".attr === 1 && "x.a".attr === 3)) @@ -504,8 +504,8 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('a === 3).subquery('l) - val right = testRelation.where('b === 2).subquery('r) + val left = testRelation.where(Symbol("a") === 3).subquery(Symbol("l")) + val right = testRelation.where(Symbol("b") === 2).subquery(Symbol("r")) val correctAnswer = left.join(right, RightOuter, Some("r.b".attr === 1)). where("l.b".attr === 2 && "l.c".attr === "r.c".attr).analyze @@ -514,8 +514,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: can't push down") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y, condition = Some("x.b".attr === "y.b".attr)) @@ -526,8 +526,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: conjunctive predicates") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y) @@ -535,8 +535,8 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('a === 1).subquery('x) - val right = testRelation.where('a === 1).subquery('y) + val left = testRelation.where(Symbol("a") === 1).subquery(Symbol("x")) + val right = testRelation.where(Symbol("a") === 1).subquery(Symbol("y")) val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr)) .analyze @@ -545,8 +545,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: conjunctive predicates #2") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = { x.join(y) @@ -554,8 +554,8 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('a === 1).subquery('x) - val right = testRelation.subquery('y) + val left = testRelation.where(Symbol("a") === 1).subquery(Symbol("x")) + val right = 
testRelation.subquery(Symbol("y")) val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr)) .analyze @@ -564,9 +564,9 @@ class FilterPushdownSuite extends PlanTest { } test("joins: conjunctive predicates #3") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) - val z = testRelation.subquery('z) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) + val z = testRelation.subquery(Symbol("z")) val originalQuery = { z.join(x.join(y)) @@ -575,9 +575,9 @@ class FilterPushdownSuite extends PlanTest { } val optimized = Optimize.execute(originalQuery.analyze) - val lleft = testRelation.where('a >= 3).subquery('z) - val left = testRelation.where('a === 1).subquery('x) - val right = testRelation.subquery('y) + val lleft = testRelation.where(Symbol("a") >= 3).subquery(Symbol("z")) + val left = testRelation.where(Symbol("a") === 1).subquery(Symbol("x")) + val right = testRelation.subquery(Symbol("y")) val correctAnswer = lleft.join( left.join(right, condition = Some("x.b".attr === "y.b".attr)), @@ -588,8 +588,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: push down where clause into left anti join") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = x.join(y, LeftAnti, Some("x.b".attr === "y.b".attr)) .where("x.a".attr > 10) @@ -603,8 +603,8 @@ class FilterPushdownSuite extends PlanTest { } test("joins: only push down join conditions to the right of a left anti join") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = x.join(y, LeftAnti, @@ -620,9 +620,9 @@ class FilterPushdownSuite extends PlanTest { } test("joins: only push down join conditions to the right of an existence join") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) - val fillerVal = 'val.boolean + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) + val fillerVal = Symbol("val").boolean val originalQuery = x.join(y, ExistenceJoin(fillerVal), @@ -637,19 +637,20 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } - val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) + val testRelationWithArrayType = + LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c_arr").array(IntegerType)) test("generate: predicate referenced no generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode('c_arr), alias = Some("arr")) - .where(('b >= 5) && ('a > 6)) + .generate(Explode(Symbol("c_arr")), alias = Some("arr")) + .where((Symbol("b") >= 5) && (Symbol("a") > 6)) } val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType - .where(('b >= 5) && ('a > 6)) - .generate(Explode('c_arr), alias = Some("arr")).analyze + .where((Symbol("b") >= 5) && (Symbol("a") > 6)) + .generate(Explode(Symbol("c_arr")), alias = Some("arr")).analyze } comparePlans(optimized, correctAnswer) @@ -658,15 +659,15 @@ class FilterPushdownSuite extends PlanTest { test("generate: non-deterministic predicate referenced no generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode('c_arr), alias = Some("arr")) - .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('col > 6)) + 
.generate(Explode(Symbol("c_arr")), alias = Some("arr")) + .where((Symbol("b") >= 5) && (Symbol("a") + Rand(10).as("rnd") > 6) && (Symbol("col") > 6)) } val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType - .where('b >= 5) - .generate(Explode('c_arr), alias = Some("arr")) - .where('a + Rand(10).as("rnd") > 6 && 'col > 6) + .where(Symbol("b") >= 5) + .generate(Explode(Symbol("c_arr")), alias = Some("arr")) + .where(Symbol("a") + Rand(10).as("rnd") > 6 && Symbol("col") > 6) .analyze } @@ -674,18 +675,18 @@ class FilterPushdownSuite extends PlanTest { } test("generate: part of conjuncts referenced generated column") { - val generator = Explode('c_arr) + val generator = Explode(Symbol("c_arr")) val originalQuery = { testRelationWithArrayType .generate(generator, alias = Some("arr"), outputNames = Seq("c")) - .where(('b >= 5) && ('c > 6)) + .where((Symbol("b") >= 5) && (Symbol("c") > 6)) } val optimized = Optimize.execute(originalQuery.analyze) val referenceResult = { testRelationWithArrayType - .where('b >= 5) + .where(Symbol("b") >= 5) .generate(generator, alias = Some("arr"), outputNames = Seq("c")) - .where('c > 6).analyze + .where(Symbol("c") > 6).analyze } // Since newly generated columns get different ids every time being analyzed @@ -705,8 +706,8 @@ class FilterPushdownSuite extends PlanTest { test("generate: all conjuncts referenced generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode('c_arr), alias = Some("arr")) - .where(('col > 6) || ('b > 5)).analyze + .generate(Explode(Symbol("c_arr")), alias = Some("arr")) + .where((Symbol("col") > 6) || (Symbol("b") > 5)).analyze } val optimized = Optimize.execute(originalQuery) @@ -715,24 +716,24 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: push down filter when filter on group by expression") { val originalQuery = testRelation - .groupBy('a)('a, count('b) as 'c) - .select('a, 'c) - .where('a === 2) + .groupBy(Symbol("a"))(Symbol("a"), count(Symbol("b")) as Symbol("c")) + .select(Symbol("a"), Symbol("c")) + .where(Symbol("a") === 2) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where('a === 2) - .groupBy('a)('a, count('b) as 'c) + .where(Symbol("a") === 2) + .groupBy(Symbol("a"))(Symbol("a"), count(Symbol("b")) as Symbol("c")) .analyze comparePlans(optimized, correctAnswer) } test("aggregate: don't push down filter when filter not on group by expression") { val originalQuery = testRelation - .select('a, 'b) - .groupBy('a)('a, count('b) as 'c) - .where('c === 2L) + .select(Symbol("a"), Symbol("b")) + .groupBy(Symbol("a"))(Symbol("a"), count(Symbol("b")) as Symbol("c")) + .where(Symbol("c") === 2L) val optimized = Optimize.execute(originalQuery.analyze) @@ -741,17 +742,17 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: push down filters partially which are subset of group by expressions") { val originalQuery = testRelation - .select('a, 'b) - .groupBy('a)('a, count('b) as 'c) - .where('c === 2L && 'a === 3) + .select(Symbol("a"), Symbol("b")) + .groupBy(Symbol("a"))(Symbol("a"), count(Symbol("b")) as Symbol("c")) + .where(Symbol("c") === 2L && Symbol("a") === 3) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where('a === 3) - .select('a, 'b) - .groupBy('a)('a, count('b) as 'c) - .where('c === 2L) + .where(Symbol("a") === 3) + .select(Symbol("a"), Symbol("b")) + .groupBy(Symbol("a"))(Symbol("a"), count(Symbol("b")) as Symbol("c")) + 
.where(Symbol("c") === 2L) .analyze comparePlans(optimized, correctAnswer) @@ -759,17 +760,17 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: push down filters with alias") { val originalQuery = testRelation - .select('a, 'b) - .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) - .where(('c === 2L || 'aa > 4) && 'aa < 3) + .select(Symbol("a"), Symbol("b")) + .groupBy(Symbol("a"))((Symbol("a") + 1) as Symbol("aa"), count(Symbol("b")) as Symbol("c")) + .where((Symbol("c") === 2L || Symbol("aa") > 4) && Symbol("aa") < 3) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where('a + 1 < 3) - .select('a, 'b) - .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) - .where('c === 2L || 'aa > 4) + .where(Symbol("a") + 1 < 3) + .select(Symbol("a"), Symbol("b")) + .groupBy(Symbol("a"))((Symbol("a") + 1) as Symbol("aa"), count(Symbol("b")) as Symbol("c")) + .where(Symbol("c") === 2L || Symbol("aa") > 4) .analyze comparePlans(optimized, correctAnswer) @@ -777,17 +778,17 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: push down filters with literal") { val originalQuery = testRelation - .select('a, 'b) - .groupBy('a)('a, count('b) as 'c, "s" as 'd) - .where('c === 2L && 'd === "s") + .select(Symbol("a"), Symbol("b")) + .groupBy(Symbol("a"))(Symbol("a"), count(Symbol("b")) as Symbol("c"), "s" as Symbol("d")) + .where(Symbol("c") === 2L && Symbol("d") === "s") val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .where("s" === "s") - .select('a, 'b) - .groupBy('a)('a, count('b) as 'c, "s" as 'd) - .where('c === 2L) + .select(Symbol("a"), Symbol("b")) + .groupBy(Symbol("a"))(Symbol("a"), count(Symbol("b")) as Symbol("c"), "s" as Symbol("d")) + .where(Symbol("c") === 2L) .analyze comparePlans(optimized, correctAnswer) @@ -795,16 +796,18 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: don't push down filters that are nondeterministic") { val originalQuery = testRelation - .select('a, 'b) - .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) - .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5) + .select(Symbol("a"), Symbol("b")) + .groupBy(Symbol("a"))(Symbol("a") + Rand(10) as Symbol("aa"), + count(Symbol("b")) as Symbol("c"), Rand(11).as("rnd")) + .where(Symbol("c") === 2L && Symbol("aa") + Rand(10).as("rnd") === 3 && Symbol("rnd") === 5) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) - .groupBy('a)('a + Rand(10) as 'aa, count('b) as 'c, Rand(11).as("rnd")) - .where('c === 2L && 'aa + Rand(10).as("rnd") === 3 && 'rnd === 5) + .select(Symbol("a"), Symbol("b")) + .groupBy(Symbol("a"))(Symbol("a") + Rand(10) as Symbol("aa"), + count(Symbol("b")) as Symbol("c"), Rand(11).as("rnd")) + .where(Symbol("c") === 2L && Symbol("aa") + Rand(10).as("rnd") === 3 && Symbol("rnd") === 5) .analyze comparePlans(optimized, correctAnswer) @@ -812,15 +815,15 @@ class FilterPushdownSuite extends PlanTest { test("SPARK-17712: aggregate: don't push down filters that are data-independent") { val originalQuery = LocalRelation.apply(testRelation.output, Seq.empty) - .select('a, 'b) - .groupBy('a)(count('a)) + .select(Symbol("a"), Symbol("b")) + .groupBy(Symbol("a"))(count(Symbol("a"))) .where(false) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) - .groupBy('a)(count('a)) + .select(Symbol("a"), Symbol("b")) + .groupBy(Symbol("a"))(count(Symbol("a"))) .where(false) 
.analyze @@ -829,7 +832,7 @@ class FilterPushdownSuite extends PlanTest { test("aggregate: don't push filters if the aggregate has no grouping expressions") { val originalQuery = LocalRelation.apply(testRelation.output, Seq.empty) - .select('a, 'b) + .select(Symbol("a"), Symbol("b")) .groupBy()(count(1)) .where(false) @@ -841,17 +844,17 @@ class FilterPushdownSuite extends PlanTest { } test("union") { - val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + val testRelation2 = LocalRelation(Symbol("d").int, Symbol("e").int, Symbol("f").int) val originalQuery = Union(Seq(testRelation, testRelation2)) - .where('a === 2L && 'b + Rand(10).as("rnd") === 3 && 'c > 5L) + .where(Symbol("a") === 2L && Symbol("b") + Rand(10).as("rnd") === 3 && Symbol("c") > 5L) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Union(Seq( - testRelation.where('a === 2L && 'c > 5L), - testRelation2.where('d === 2L && 'f > 5L))) - .where('b + Rand(10).as("rnd") === 3) + testRelation.where(Symbol("a") === 2L && Symbol("c") > 5L), + testRelation2.where(Symbol("d") === 2L && Symbol("f") > 5L))) + .where(Symbol("b") + Rand(10).as("rnd") === 3) .analyze comparePlans(optimized, correctAnswer) @@ -859,7 +862,7 @@ class FilterPushdownSuite extends PlanTest { test("expand") { val agg = testRelation - .groupBy(Cube(Seq('a, 'b)))('a, 'b, sum('c)) + .groupBy(Cube(Seq(Symbol("a"), Symbol("b"))))(Symbol("a"), Symbol("b"), sum(Symbol("c"))) .analyze .asInstanceOf[Aggregate] @@ -873,9 +876,9 @@ class FilterPushdownSuite extends PlanTest { } test("predicate subquery: push down simple") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) - val z = LocalRelation('a.int, 'b.int, 'c.int).subquery('z) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) + val z = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int).subquery(Symbol("z")) val query = x .join(y, Inner, Option("x.a".attr === "y.a".attr)) @@ -890,10 +893,10 @@ class FilterPushdownSuite extends PlanTest { } test("predicate subquery: push down complex") { - val w = testRelation.subquery('w) - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) - val z = LocalRelation('a.int, 'b.int, 'c.int).subquery('z) + val w = testRelation.subquery(Symbol("w")) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) + val z = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int).subquery(Symbol("z")) val query = w .join(x, Inner, Option("w.a".attr === "x.a".attr)) @@ -910,9 +913,9 @@ class FilterPushdownSuite extends PlanTest { } test("SPARK-20094: don't push predicate with IN subquery into join condition") { - val x = testRelation.subquery('x) - val z = testRelation.subquery('z) - val w = testRelation1.subquery('w) + val x = testRelation.subquery(Symbol("x")) + val z = testRelation.subquery(Symbol("z")) + val w = testRelation1.subquery(Symbol("w")) val queryPlan = x .join(z) @@ -930,66 +933,80 @@ class FilterPushdownSuite extends PlanTest { } test("Window: predicate push down -- basic") { - val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) + val winExpr = windowExpr(count(Symbol("b")), + windowSpec(Symbol("a") :: Nil, Symbol("b").asc :: Nil, UnspecifiedFrame)) - val originalQuery = testRelation.select('a, 'b, 'c, winExpr.as('window)).where('a > 1) + val originalQuery = testRelation.select(Symbol("a"), Symbol("b"), Symbol("c"), + winExpr.as(Symbol("window"))).where(Symbol("a") > 1) val 
correctAnswer = testRelation - .where('a > 1).select('a, 'b, 'c) - .window(winExpr.as('window) :: Nil, 'a :: Nil, 'b.asc :: Nil) - .select('a, 'b, 'c, 'window).analyze + .where(Symbol("a") > 1).select(Symbol("a"), Symbol("b"), Symbol("c")) + .window(winExpr.as(Symbol("window")) :: Nil, Symbol("a") :: Nil, Symbol("b").asc :: Nil) + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("window")).analyze comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer) } test("Window: predicate push down -- predicates with compound predicate using only one column") { val winExpr = - windowExpr(count('b), windowSpec('a.attr :: 'b.attr :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) + windowExpr(count(Symbol("b")), windowSpec(Symbol("a").attr :: Symbol("b").attr :: Nil, + Symbol("b").asc :: Nil, UnspecifiedFrame)) - val originalQuery = testRelation.select('a, 'b, 'c, winExpr.as('window)).where('a * 3 > 15) + val originalQuery = testRelation.select(Symbol("a"), Symbol("b"), Symbol("c"), + winExpr.as(Symbol("window"))).where(Symbol("a") * 3 > 15) val correctAnswer = testRelation - .where('a * 3 > 15).select('a, 'b, 'c) - .window(winExpr.as('window) :: Nil, 'a.attr :: 'b.attr :: Nil, 'b.asc :: Nil) - .select('a, 'b, 'c, 'window).analyze + .where(Symbol("a") * 3 > 15).select(Symbol("a"), Symbol("b"), Symbol("c")) + .window(winExpr.as(Symbol("window")) :: Nil, + Symbol("a").attr :: Symbol("b").attr :: Nil, Symbol("b").asc :: Nil) + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("window")).analyze comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer) } test("Window: predicate push down -- multi window expressions with the same window spec") { - val winSpec = windowSpec('a.attr :: 'b.attr :: Nil, 'b.asc :: Nil, UnspecifiedFrame) - val winExpr1 = windowExpr(count('b), winSpec) - val winExpr2 = windowExpr(sum('b), winSpec) + val winSpec = windowSpec(Symbol("a").attr :: Symbol("b").attr :: Nil, + Symbol("b").asc :: Nil, UnspecifiedFrame) + val winExpr1 = windowExpr(count(Symbol("b")), winSpec) + val winExpr2 = windowExpr(sum(Symbol("b")), winSpec) val originalQuery = testRelation - .select('a, 'b, 'c, winExpr1.as('window1), winExpr2.as('window2)).where('a > 1) + .select(Symbol("a"), Symbol("b"), Symbol("c"), winExpr1.as(Symbol("window1")), + winExpr2.as(Symbol("window2"))).where(Symbol("a") > 1) val correctAnswer = testRelation - .where('a > 1).select('a, 'b, 'c) - .window(winExpr1.as('window1) :: winExpr2.as('window2) :: Nil, - 'a.attr :: 'b.attr :: Nil, 'b.asc :: Nil) - .select('a, 'b, 'c, 'window1, 'window2).analyze + .where(Symbol("a") > 1).select(Symbol("a"), Symbol("b"), Symbol("c")) + .window(winExpr1.as(Symbol("window1")) :: winExpr2.as(Symbol("window2")) :: Nil, + Symbol("a").attr :: Symbol("b").attr :: Nil, Symbol("b").asc :: Nil) + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("window1"), Symbol("window2")).analyze comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer) } test("Window: predicate push down -- multi window specification - 1") { // order by clauses are different between winSpec1 and winSpec2 - val winSpec1 = windowSpec('a.attr :: 'b.attr :: Nil, 'b.asc :: Nil, UnspecifiedFrame) - val winExpr1 = windowExpr(count('b), winSpec1) - val winSpec2 = windowSpec('a.attr :: 'b.attr :: Nil, 'a.asc :: Nil, UnspecifiedFrame) - val winExpr2 = windowExpr(count('b), winSpec2) + val winSpec1 = windowSpec(Symbol("a").attr :: Symbol("b").attr :: Nil, + Symbol("b").asc :: Nil, UnspecifiedFrame) + val winExpr1 = windowExpr(count(Symbol("b")), winSpec1) + val 
winSpec2 = windowSpec(Symbol("a").attr :: Symbol("b").attr :: Nil, + Symbol("a").asc :: Nil, UnspecifiedFrame) + val winExpr2 = windowExpr(count(Symbol("b")), winSpec2) val originalQuery = testRelation - .select('a, 'b, 'c, winExpr1.as('window1), winExpr2.as('window2)).where('a > 1) + .select(Symbol("a"), Symbol("b"), Symbol("c"), winExpr1.as(Symbol("window1")), + winExpr2.as(Symbol("window2"))).where(Symbol("a") > 1) val correctAnswer1 = testRelation - .where('a > 1).select('a, 'b, 'c) - .window(winExpr1.as('window1) :: Nil, 'a.attr :: 'b.attr :: Nil, 'b.asc :: Nil) - .window(winExpr2.as('window2) :: Nil, 'a.attr :: 'b.attr :: Nil, 'a.asc :: Nil) - .select('a, 'b, 'c, 'window1, 'window2).analyze + .where(Symbol("a") > 1).select(Symbol("a"), Symbol("b"), Symbol("c")) + .window(winExpr1.as(Symbol("window1")) :: Nil, + Symbol("a").attr :: Symbol("b").attr :: Nil, Symbol("b").asc :: Nil) + .window(winExpr2.as(Symbol("window2")) :: Nil, + Symbol("a").attr :: Symbol("b").attr :: Nil, Symbol("a").asc :: Nil) + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("window1"), Symbol("window2")).analyze val correctAnswer2 = testRelation - .where('a > 1).select('a, 'b, 'c) - .window(winExpr2.as('window2) :: Nil, 'a.attr :: 'b.attr :: Nil, 'a.asc :: Nil) - .window(winExpr1.as('window1) :: Nil, 'a.attr :: 'b.attr :: Nil, 'b.asc :: Nil) - .select('a, 'b, 'c, 'window1, 'window2).analyze + .where(Symbol("a") > 1).select(Symbol("a"), Symbol("b"), Symbol("c")) + .window(winExpr2.as(Symbol("window2")) :: Nil, + Symbol("a").attr :: Symbol("b").attr :: Nil, Symbol("a").asc :: Nil) + .window(winExpr1.as(Symbol("window1")) :: Nil, + Symbol("a").attr :: Symbol("b").attr :: Nil, Symbol("b").asc :: Nil) + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("window1"), Symbol("window2")).analyze // When Analyzer adding Window operators after grouping the extracted Window Expressions // based on their Partition and Order Specs, the order of Window operators is @@ -1004,24 +1021,29 @@ class FilterPushdownSuite extends PlanTest { test("Window: predicate push down -- multi window specification - 2") { // partitioning clauses are different between winSpec1 and winSpec2 - val winSpec1 = windowSpec('a.attr :: Nil, 'b.asc :: Nil, UnspecifiedFrame) - val winExpr1 = windowExpr(count('b), winSpec1) - val winSpec2 = windowSpec('b.attr :: Nil, 'b.asc :: Nil, UnspecifiedFrame) - val winExpr2 = windowExpr(count('a), winSpec2) + val winSpec1 = windowSpec(Symbol("a").attr :: Nil, Symbol("b").asc :: Nil, UnspecifiedFrame) + val winExpr1 = windowExpr(count(Symbol("b")), winSpec1) + val winSpec2 = windowSpec(Symbol("b").attr :: Nil, Symbol("b").asc :: Nil, UnspecifiedFrame) + val winExpr2 = windowExpr(count(Symbol("a")), winSpec2) val originalQuery = testRelation - .select('a, winExpr1.as('window1), 'b, 'c, winExpr2.as('window2)).where('b > 1) - - val correctAnswer1 = testRelation.select('a, 'b, 'c) - .window(winExpr1.as('window1) :: Nil, 'a.attr :: Nil, 'b.asc :: Nil) - .where('b > 1) - .window(winExpr2.as('window2) :: Nil, 'b.attr :: Nil, 'b.asc :: Nil) - .select('a, 'window1, 'b, 'c, 'window2).analyze - - val correctAnswer2 = testRelation.select('a, 'b, 'c) - .window(winExpr2.as('window2) :: Nil, 'b.attr :: Nil, 'b.asc :: Nil) - .window(winExpr1.as('window1) :: Nil, 'a.attr :: Nil, 'b.asc :: Nil) - .where('b > 1) - .select('a, 'window1, 'b, 'c, 'window2).analyze + .select(Symbol("a"), winExpr1.as(Symbol("window1")), + Symbol("b"), Symbol("c"), winExpr2.as(Symbol("window2"))).where(Symbol("b") > 1) + + val correctAnswer1 = 
testRelation.select(Symbol("a"), Symbol("b"), Symbol("c")) + .window(winExpr1.as(Symbol("window1")) :: Nil, + Symbol("a").attr :: Nil, Symbol("b").asc :: Nil) + .where(Symbol("b") > 1) + .window(winExpr2.as(Symbol("window2")) :: Nil, + Symbol("b").attr :: Nil, Symbol("b").asc :: Nil) + .select(Symbol("a"), Symbol("window1"), Symbol("b"), Symbol("c"), Symbol("window2")).analyze + + val correctAnswer2 = testRelation.select(Symbol("a"), Symbol("b"), Symbol("c")) + .window(winExpr2.as(Symbol("window2")) :: Nil, + Symbol("b").attr :: Nil, Symbol("b").asc :: Nil) + .window(winExpr1.as(Symbol("window1")) :: Nil, + Symbol("a").attr :: Nil, Symbol("b").asc :: Nil) + .where(Symbol("b") > 1) + .select(Symbol("a"), Symbol("window1"), Symbol("b"), Symbol("c"), Symbol("window2")).analyze val optimizedQuery = Optimize.execute(originalQuery.analyze) // When Analyzer adding Window operators after grouping the extracted Window Expressions @@ -1036,13 +1058,16 @@ class FilterPushdownSuite extends PlanTest { test("Window: predicate push down -- predicates with multiple partitioning columns") { val winExpr = - windowExpr(count('b), windowSpec('a.attr :: 'b.attr :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) + windowExpr(count(Symbol("b")), windowSpec(Symbol("a").attr :: Symbol("b").attr :: Nil, + Symbol("b").asc :: Nil, UnspecifiedFrame)) - val originalQuery = testRelation.select('a, 'b, 'c, winExpr.as('window)).where('a + 'b > 1) + val originalQuery = testRelation.select(Symbol("a"), Symbol("b"), Symbol("c"), + winExpr.as(Symbol("window"))).where(Symbol("a") + Symbol("b") > 1) val correctAnswer = testRelation - .where('a + 'b > 1).select('a, 'b, 'c) - .window(winExpr.as('window) :: Nil, 'a.attr :: 'b.attr :: Nil, 'b.asc :: Nil) - .select('a, 'b, 'c, 'window).analyze + .where(Symbol("a") + Symbol("b") > 1).select(Symbol("a"), Symbol("b"), Symbol("c")) + .window(winExpr.as(Symbol("window")) :: Nil, + Symbol("a").attr :: Symbol("b").attr :: Nil, Symbol("b").asc :: Nil) + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("window")).analyze comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer) } @@ -1052,75 +1077,88 @@ class FilterPushdownSuite extends PlanTest { // to the alias that is defined as the same expression ignore("Window: predicate push down -- complex predicate with the same expressions") { val winSpec = windowSpec( - partitionSpec = 'a.attr + 'b.attr :: Nil, - orderSpec = 'b.asc :: Nil, + partitionSpec = Symbol("a").attr + Symbol("b").attr :: Nil, + orderSpec = Symbol("b").asc :: Nil, UnspecifiedFrame) - val winExpr = windowExpr(count('b), winSpec) + val winExpr = windowExpr(count(Symbol("b")), winSpec) val winSpecAnalyzed = windowSpec( - partitionSpec = '_w0.attr :: Nil, - orderSpec = 'b.asc :: Nil, + partitionSpec = Symbol("_w0").attr :: Nil, + orderSpec = Symbol("b").asc :: Nil, UnspecifiedFrame) - val winExprAnalyzed = windowExpr(count('b), winSpecAnalyzed) + val winExprAnalyzed = windowExpr(count(Symbol("b")), winSpecAnalyzed) - val originalQuery = testRelation.select('a, 'b, 'c, winExpr.as('window)).where('a + 'b > 1) + val originalQuery = testRelation.select(Symbol("a"), Symbol("b"), Symbol("c"), + winExpr.as(Symbol("window"))).where(Symbol("a") + Symbol("b") > 1) val correctAnswer = testRelation - .where('a + 'b > 1).select('a, 'b, 'c, ('a + 'b).as("_w0")) - .window(winExprAnalyzed.as('window) :: Nil, '_w0 :: Nil, 'b.asc :: Nil) - .select('a, 'b, 'c, 'window).analyze + .where(Symbol("a") + Symbol("b") > 1) + .select(Symbol("a"), Symbol("b"), Symbol("c"), (Symbol("a") + 
Symbol("b")).as("_w0")) + .window(winExprAnalyzed.as(Symbol("window")) :: Nil, + Symbol("_w0") :: Nil, Symbol("b").asc :: Nil) + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("window")).analyze comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer) } test("Window: no predicate push down -- predicates are not from partitioning keys") { val winSpec = windowSpec( - partitionSpec = 'a.attr :: 'b.attr :: Nil, - orderSpec = 'b.asc :: Nil, + partitionSpec = Symbol("a").attr :: Symbol("b").attr :: Nil, + orderSpec = Symbol("b").asc :: Nil, UnspecifiedFrame) - val winExpr = windowExpr(count('b), winSpec) + val winExpr = windowExpr(count(Symbol("b")), winSpec) // No push down: the predicate is c > 1, but the partitioning key is (a, b). - val originalQuery = testRelation.select('a, 'b, 'c, winExpr.as('window)).where('c > 1) - val correctAnswer = testRelation.select('a, 'b, 'c) - .window(winExpr.as('window) :: Nil, 'a.attr :: 'b.attr :: Nil, 'b.asc :: Nil) - .where('c > 1).select('a, 'b, 'c, 'window).analyze + val originalQuery = testRelation.select(Symbol("a"), Symbol("b"), Symbol("c"), + winExpr.as(Symbol("window"))).where(Symbol("c") > 1) + val correctAnswer = testRelation.select(Symbol("a"), Symbol("b"), Symbol("c")) + .window(winExpr.as(Symbol("window")) :: Nil, + Symbol("a").attr :: Symbol("b").attr :: Nil, Symbol("b").asc :: Nil) + .where(Symbol("c") > 1) + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("window")).analyze comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer) } test("Window: no predicate push down -- partial compound partition key") { val winSpec = windowSpec( - partitionSpec = 'a.attr + 'b.attr :: 'b.attr :: Nil, - orderSpec = 'b.asc :: Nil, + partitionSpec = Symbol("a").attr + Symbol("b").attr :: Symbol("b").attr :: Nil, + orderSpec = Symbol("b").asc :: Nil, UnspecifiedFrame) - val winExpr = windowExpr(count('b), winSpec) + val winExpr = windowExpr(count(Symbol("b")), winSpec) // No push down: the predicate is a > 1, but the partitioning key is (a + b, b) - val originalQuery = testRelation.select('a, 'b, 'c, winExpr.as('window)).where('a > 1) + val originalQuery = testRelation.select(Symbol("a"), Symbol("b"), Symbol("c"), + winExpr.as(Symbol("window"))).where(Symbol("a") > 1) val winSpecAnalyzed = windowSpec( - partitionSpec = '_w0.attr :: 'b.attr :: Nil, - orderSpec = 'b.asc :: Nil, + partitionSpec = Symbol("_w0").attr :: Symbol("b").attr :: Nil, + orderSpec = Symbol("b").asc :: Nil, UnspecifiedFrame) - val winExprAnalyzed = windowExpr(count('b), winSpecAnalyzed) - val correctAnswer = testRelation.select('a, 'b, 'c, ('a + 'b).as("_w0")) - .window(winExprAnalyzed.as('window) :: Nil, '_w0 :: 'b.attr :: Nil, 'b.asc :: Nil) - .where('a > 1).select('a, 'b, 'c, 'window).analyze + val winExprAnalyzed = windowExpr(count(Symbol("b")), winSpecAnalyzed) + val correctAnswer = testRelation.select( + Symbol("a"), Symbol("b"), Symbol("c"), (Symbol("a") + Symbol("b")).as("_w0")) + .window(winExprAnalyzed.as(Symbol("window")) :: Nil, + Symbol("_w0") :: Symbol("b").attr :: Nil, Symbol("b").asc :: Nil) + .where(Symbol("a") > 1) + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("window")).analyze comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer) } test("Window: no predicate push down -- complex predicates containing non partitioning columns") { val winSpec = - windowSpec(partitionSpec = 'b.attr :: Nil, orderSpec = 'b.asc :: Nil, UnspecifiedFrame) - val winExpr = windowExpr(count('b), winSpec) + windowSpec(partitionSpec = 
Symbol("b").attr :: Nil, + orderSpec = Symbol("b").asc :: Nil, UnspecifiedFrame) + val winExpr = windowExpr(count(Symbol("b")), winSpec) // No push down: the predicate is a + b > 1, but the partitioning key is b. - val originalQuery = testRelation.select('a, 'b, 'c, winExpr.as('window)).where('a + 'b > 1) + val originalQuery = testRelation.select(Symbol("a"), Symbol("b"), + Symbol("c"), winExpr.as(Symbol("window"))).where(Symbol("a") + Symbol("b") > 1) val correctAnswer = testRelation - .select('a, 'b, 'c) - .window(winExpr.as('window) :: Nil, 'b.attr :: Nil, 'b.asc :: Nil) - .where('a + 'b > 1).select('a, 'b, 'c, 'window).analyze + .select(Symbol("a"), Symbol("b"), Symbol("c")) + .window(winExpr.as(Symbol("window")) :: Nil, Symbol("b").attr :: Nil, Symbol("b").asc :: Nil) + .where(Symbol("a") + Symbol("b") > 1) + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("window")).analyze comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer) } @@ -1128,52 +1166,57 @@ class FilterPushdownSuite extends PlanTest { // complex predicates with the same references but different expressions test("Window: no predicate push down -- complex predicate with different expressions") { val winSpec = windowSpec( - partitionSpec = 'a.attr + 'b.attr :: Nil, - orderSpec = 'b.asc :: Nil, + partitionSpec = Symbol("a").attr + Symbol("b").attr :: Nil, + orderSpec = Symbol("b").asc :: Nil, UnspecifiedFrame) - val winExpr = windowExpr(count('b), winSpec) + val winExpr = windowExpr(count(Symbol("b")), winSpec) val winSpecAnalyzed = windowSpec( - partitionSpec = '_w0.attr :: Nil, - orderSpec = 'b.asc :: Nil, + partitionSpec = Symbol("_w0").attr :: Nil, + orderSpec = Symbol("b").asc :: Nil, UnspecifiedFrame) - val winExprAnalyzed = windowExpr(count('b), winSpecAnalyzed) + val winExprAnalyzed = windowExpr(count(Symbol("b")), winSpecAnalyzed) // No push down: the predicate is a + b > 1, but the partitioning key is a + b. - val originalQuery = testRelation.select('a, 'b, 'c, winExpr.as('window)).where('a - 'b > 1) - val correctAnswer = testRelation.select('a, 'b, 'c, ('a + 'b).as("_w0")) - .window(winExprAnalyzed.as('window) :: Nil, '_w0 :: Nil, 'b.asc :: Nil) - .where('a - 'b > 1).select('a, 'b, 'c, 'window).analyze + val originalQuery = testRelation.select(Symbol("a"), Symbol("b"), Symbol("c"), + winExpr.as(Symbol("window"))).where(Symbol("a") - Symbol("b") > 1) + val correctAnswer = testRelation.select(Symbol("a"), Symbol("b"), + Symbol("c"), (Symbol("a") + Symbol("b")).as("_w0")) + .window(winExprAnalyzed.as(Symbol("window")) :: Nil, + Symbol("_w0") :: Nil, Symbol("b").asc :: Nil) + .where(Symbol("a") - Symbol("b") > 1) + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("window")).analyze comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer) } test("watermark pushdown: no pushdown on watermark attribute #1") { val interval = new CalendarInterval(2, 2, 2000L) - val relation = LocalRelation(attrA, 'b.timestamp, attrC) + val relation = LocalRelation(attrA, Symbol("b").timestamp, attrC) // Verify that all conditions except the watermark touching condition are pushed down // by the optimizer and others are not. 
- val originalQuery = EventTimeWatermark('b, interval, relation)
- .where('a === 5 && 'b === new java.sql.Timestamp(0) && 'c === 5)
+ val originalQuery = EventTimeWatermark(Symbol("b"), interval, relation)
+ .where(Symbol("a") === 5 && Symbol("b") === new java.sql.Timestamp(0) && Symbol("c") === 5)
 val correctAnswer = EventTimeWatermark(
- 'b, interval, relation.where('a === 5 && 'c === 5))
- .where('b === new java.sql.Timestamp(0))
+ Symbol("b"), interval, relation.where(Symbol("a") === 5 && Symbol("c") === 5))
+ .where(Symbol("b") === new java.sql.Timestamp(0))
 comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
 }
 test("watermark pushdown: no pushdown for nondeterministic filter") {
 val interval = new CalendarInterval(2, 2, 2000L)
- val relation = LocalRelation(attrA, attrB, 'c.timestamp)
+ val relation = LocalRelation(attrA, attrB, Symbol("c").timestamp)
 // Verify that all conditions except the watermark touching condition are pushed down
 // by the optimizer and others are not.
- val originalQuery = EventTimeWatermark('c, interval, relation)
- .where('a === 5 && 'b === Rand(10) && 'c === new java.sql.Timestamp(0))
+ val originalQuery = EventTimeWatermark(Symbol("c"), interval, relation)
+ .where(Symbol("a") === 5 && Symbol("b") === Rand(10) &&
+ Symbol("c") === new java.sql.Timestamp(0))
 val correctAnswer = EventTimeWatermark(
- 'c, interval, relation.where('a === 5))
- .where('b === Rand(10) && 'c === new java.sql.Timestamp(0))
+ Symbol("c"), interval, relation.where(Symbol("a") === 5))
+ .where(Symbol("b") === Rand(10) && Symbol("c") === new java.sql.Timestamp(0))
 comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
 checkAnalysis = false)
@@ -1181,14 +1224,14 @@ class FilterPushdownSuite extends PlanTest {
 test("watermark pushdown: full pushdown") {
 val interval = new CalendarInterval(2, 2, 2000L)
- val relation = LocalRelation(attrA, attrB, 'c.timestamp)
+ val relation = LocalRelation(attrA, attrB, Symbol("c").timestamp)
 // Verify that all conditions except the watermark touching condition are pushed down
 // by the optimizer and others are not.
- val originalQuery = EventTimeWatermark('c, interval, relation)
- .where('a === 5 && 'b === 10)
+ val originalQuery = EventTimeWatermark(Symbol("c"), interval, relation)
+ .where(Symbol("a") === 5 && Symbol("b") === 10)
 val correctAnswer = EventTimeWatermark(
- 'c, interval, relation.where('a === 5 && 'b === 10))
+ Symbol("c"), interval, relation.where(Symbol("a") === 5 && Symbol("b") === 10))
 comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
 checkAnalysis = false)
@@ -1196,12 +1239,13 @@ class FilterPushdownSuite extends PlanTest {
 test("watermark pushdown: no pushdown on watermark attribute #2") {
 val interval = new CalendarInterval(2, 2, 2000L)
- val relation = LocalRelation('a.timestamp, attrB, attrC)
+ val relation = LocalRelation(Symbol("a").timestamp, attrB, attrC)
- val originalQuery = EventTimeWatermark('a, interval, relation)
- .where('a === new java.sql.Timestamp(0) && 'b === 10)
+ val originalQuery = EventTimeWatermark(Symbol("a"), interval, relation)
+ .where(Symbol("a") === new java.sql.Timestamp(0) && Symbol("b") === 10)
 val correctAnswer = EventTimeWatermark(
- 'a, interval, relation.where('b === 10)).where('a === new java.sql.Timestamp(0))
+ Symbol("a"), interval, relation.where(Symbol("b") === 10))
+ .where(Symbol("a") === new java.sql.Timestamp(0))
 comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze,
 checkAnalysis = false)
@@ -1209,22 +1253,22 @@ class FilterPushdownSuite extends PlanTest {
 test("push down predicate through expand") {
 val query =
- Filter('a > 1,
+ Filter(Symbol("a") > 1,
 Expand(
 Seq(
- Seq('a, 'b, 'c, Literal.create(null, StringType), 1),
- Seq('a, 'b, 'c, 'a, 2)),
- Seq('a, 'b, 'c),
+ Seq(Symbol("a"), Symbol("b"), Symbol("c"), Literal.create(null, StringType), 1),
+ Seq(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("a"), 2)),
+ Seq(Symbol("a"), Symbol("b"), Symbol("c")),
 testRelation)).analyze
 val optimized = Optimize.execute(query)
 val expected =
 Expand(
 Seq(
- Seq('a, 'b, 'c, Literal.create(null, StringType), 1),
- Seq('a, 'b, 'c, 'a, 2)),
- Seq('a, 'b, 'c),
- Filter('a > 1, testRelation)).analyze
+ Seq(Symbol("a"), Symbol("b"), Symbol("c"), Literal.create(null, StringType), 1),
+ Seq(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("a"), 2)),
+ Seq(Symbol("a"), Symbol("b"), Symbol("c")),
+ Filter(Symbol("a") > 1, testRelation)).analyze
 comparePlans(optimized, expected)
 }
@@ -1252,18 +1296,19 @@ class FilterPushdownSuite extends PlanTest {
 }
 test("push down filter predicates through inner join") {
- val x = testRelation.subquery('x)
- val y = testRelation.subquery('y)
+ val x = testRelation.subquery(Symbol("x"))
+ val y = testRelation.subquery(Symbol("y"))
- val originalQuery = x.join(y).where(("x.b".attr === "y.b".attr) && (simpleDisjunctivePredicate))
+ val originalQuery =
+ x.join(y).where(("x.b".attr === "y.b".attr) && (simpleDisjunctivePredicate))
 val optimized = Optimize.execute(originalQuery.analyze)
 comparePlans(optimized, expectedPredicatePushDownResult)
 }
 test("push down join predicates through inner join") {
- val x = testRelation.subquery('x)
- val y = testRelation.subquery('y)
+ val x = testRelation.subquery(Symbol("x"))
+ val y = testRelation.subquery(Symbol("y"))
 val originalQuery =
 x.join(y, condition = Some(("x.b".attr === "y.b".attr) && (simpleDisjunctivePredicate)))
@@ -1273,8 +1318,8 @@ class FilterPushdownSuite extends PlanTest {
 }
 test("push down complex predicates through inner join") {
- val x = testRelation.subquery('x)
- val y = testRelation.subquery('y)
+ val x =
testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val joinCondition = (("x.b".attr === "y.b".attr) && ((("x.a".attr === 5) && ("y.a".attr >= 2) && ("y.a".attr <= 3)) @@ -1284,17 +1329,18 @@ class FilterPushdownSuite extends PlanTest { val originalQuery = x.join(y, condition = Some(joinCondition)) val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation.where( - ('a === 5 || 'a === 2 || 'a === 1)).subquery('x) + (Symbol("a") === 5 || Symbol("a") === 2 || Symbol("a") === 1)).subquery(Symbol("x")) val right = testRelation.where( - ('a >= 2 && 'a <= 3) || ('a >= 1 && 'a <= 14) || ('a >= 9 && 'a <= 27)).subquery('y) + (Symbol("a") >= 2 && Symbol("a") <= 3) || (Symbol("a") >= 1 && + Symbol("a") <= 14) || (Symbol("a") >= 9 && Symbol("a") <= 27)).subquery(Symbol("y")) val correctAnswer = left.join(right, condition = Some(joinCondition)).analyze comparePlans(optimized, correctAnswer) } test("push down predicates(with NOT predicate) through inner join") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = x.join(y, condition = Some(("x.b".attr === "y.b".attr) @@ -1302,8 +1348,8 @@ class FilterPushdownSuite extends PlanTest { && ("x.a".attr < 2 || ("y.a".attr > 13)) || ("x.a".attr > 1) && ("y.a".attr > 11)))) val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('a <= 3 || 'a >= 2).subquery('x) - val right = testRelation.subquery('y) + val left = testRelation.where(Symbol("a") <= 3 || Symbol("a") >= 2).subquery(Symbol("x")) + val right = testRelation.subquery(Symbol("y")) val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr && (("x.a".attr <= 3) || (("x.a".attr >= 2) && ("y.a".attr <= 13))) @@ -1313,16 +1359,16 @@ class FilterPushdownSuite extends PlanTest { } test("push down predicates through left join") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = x.join(y, joinType = LeftOuter, condition = Some(("x.b".attr === "y.b".attr) && simpleDisjunctivePredicate)) val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.subquery('x) - val right = testRelation.where('a > 13 || 'a > 11).subquery('y) + val left = testRelation.subquery(Symbol("x")) + val right = testRelation.where(Symbol("a") > 13 || Symbol("a") > 11).subquery(Symbol("y")) val correctAnswer = left.join(right, joinType = LeftOuter, condition = Some("x.b".attr === "y.b".attr && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) @@ -1332,16 +1378,16 @@ class FilterPushdownSuite extends PlanTest { } test("push down predicates through right join") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = x.join(y, joinType = RightOuter, condition = Some(("x.b".attr === "y.b".attr) && simpleDisjunctivePredicate)) val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('a > 3 || 'a > 1).subquery('x) - val right = testRelation.subquery('y) + val left = testRelation.where(Symbol("a") > 3 || Symbol("a") > 1).subquery(Symbol("x")) + val right = testRelation.subquery(Symbol("y")) val correctAnswer = left.join(right, joinType = RightOuter, condition = Some("x.b".attr 
=== "y.b".attr && (("x.a".attr > 3) && ("y.a".attr > 13) || ("x.a".attr > 1) && ("y.a".attr > 11)))) @@ -1351,16 +1397,17 @@ class FilterPushdownSuite extends PlanTest { } test("SPARK-32302: avoid generating too many predicates") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = x.join(y, condition = Some(("x.b".attr === "y.b".attr) && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) || (("y.a".attr > 2) && ("y.c".attr < 1))))) val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.subquery('x) - val right = testRelation.where('c <= 5 || ('a > 2 && 'c < 1)).subquery('y) + val left = testRelation.subquery(Symbol("x")) + val right = testRelation.where(Symbol("c") <= 5 || (Symbol("a") > 2 && + Symbol("c") < 1)).subquery(Symbol("y")) val correctAnswer = left.join(right, condition = Some("x.b".attr === "y.b".attr && ((("x.a".attr > 3) && ("x.a".attr < 13) && ("y.c".attr <= 5)) || (("y.a".attr > 2) && ("y.c".attr < 1))))).analyze @@ -1369,16 +1416,16 @@ class FilterPushdownSuite extends PlanTest { } test("push down predicate through multiple joins") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) - val z = testRelation.subquery('z) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) + val z = testRelation.subquery(Symbol("z")) val xJoinY = x.join(y, condition = Some("x.b".attr === "y.b".attr)) val originalQuery = z.join(xJoinY, condition = Some("x.a".attr === "z.a".attr && simpleDisjunctivePredicate)) val optimized = Optimize.execute(originalQuery.analyze) - val left = x.where('a > 3 || 'a > 1) - val right = y.where('a > 13 || 'a > 11) + val left = x.where(Symbol("a") > 3 || Symbol("a") > 1) + val right = y.where(Symbol("a") > 13 || Symbol("a") > 11) val correctAnswer = z.join(left.join(right, condition = Some("x.b".attr === "y.b".attr && simpleDisjunctivePredicate)), condition = Some("x.a".attr === "z.a".attr)).analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index 92e4fa345e2a..d0c29c681098 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -31,84 +31,88 @@ class FoldablePropagationSuite extends PlanTest { FoldablePropagation) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int) test("Propagate from subquery") { val query = OneRowRelation() - .select(Literal(1).as('a), Literal(2).as('b)) - .subquery('T) - .select('a, 'b) + .select(Literal(1).as(Symbol("a")), Literal(2).as(Symbol("b"))) + .subquery(Symbol("T")) + .select(Symbol("a"), Symbol("b")) val optimized = Optimize.execute(query.analyze) val correctAnswer = OneRowRelation() - .select(Literal(1).as('a), Literal(2).as('b)) - .subquery('T) - .select(Literal(1).as('a), Literal(2).as('b)).analyze + .select(Literal(1).as(Symbol("a")), Literal(2).as(Symbol("b"))) + .subquery(Symbol("T")) + .select(Literal(1).as(Symbol("a")), Literal(2).as(Symbol("b"))).analyze comparePlans(optimized, correctAnswer) } test("Propagate to select clause") { val query = testRelation - .select('a.as('x), "str".as('y), 
'b.as('z)) - .select('x, 'y, 'z) + .select(Symbol("a").as(Symbol("x")), "str".as(Symbol("y")), Symbol("b").as(Symbol("z"))) + .select(Symbol("x"), Symbol("y"), Symbol("z")) val optimized = Optimize.execute(query.analyze) val correctAnswer = testRelation - .select('a.as('x), "str".as('y), 'b.as('z)) - .select('x, "str".as('y), 'z).analyze + .select(Symbol("a").as(Symbol("x")), "str".as(Symbol("y")), Symbol("b").as(Symbol("z"))) + .select(Symbol("x"), "str".as(Symbol("y")), Symbol("z")).analyze comparePlans(optimized, correctAnswer) } test("Propagate to where clause") { val query = testRelation - .select("str".as('y)) - .where('y === "str" && "str" === 'y) + .select("str".as(Symbol("y"))) + .where(Symbol("y") === "str" && "str" === Symbol("y")) val optimized = Optimize.execute(query.analyze) val correctAnswer = testRelation - .select("str".as('y)) - .where("str".as('y) === "str" && "str" === "str".as('y)).analyze + .select("str".as(Symbol("y"))) + .where("str".as(Symbol("y")) === "str" && "str" === "str".as(Symbol("y"))).analyze comparePlans(optimized, correctAnswer) } test("Propagate to orderBy clause") { val query = testRelation - .select('a.as('x), Year(CurrentDate()).as('y), 'b) - .orderBy('x.asc, 'y.asc, 'b.desc) + .select(Symbol("a").as(Symbol("x")), Year(CurrentDate()).as(Symbol("y")), Symbol("b")) + .orderBy(Symbol("x").asc, Symbol("y").asc, Symbol("b").desc) val optimized = Optimize.execute(query.analyze) val correctAnswer = testRelation - .select('a.as('x), Year(CurrentDate()).as('y), 'b) - .orderBy('x.asc, SortOrder(Year(CurrentDate()), Ascending), 'b.desc).analyze + .select(Symbol("a").as(Symbol("x")), Year(CurrentDate()).as(Symbol("y")), Symbol("b")) + .orderBy(Symbol("x").asc, SortOrder(Year(CurrentDate()), Ascending), Symbol("b").desc).analyze comparePlans(optimized, correctAnswer) } test("Propagate to groupBy clause") { val query = testRelation - .select('a.as('x), Year(CurrentDate()).as('y), 'b) - .groupBy('x, 'y, 'b)(sum('x), avg('y).as('AVG), count('b)) + .select(Symbol("a").as(Symbol("x")), Year(CurrentDate()).as(Symbol("y")), Symbol("b")) + .groupBy(Symbol("x"), Symbol("y"), Symbol("b"))(sum(Symbol("x")), + avg(Symbol("y")).as(Symbol("AVG")), count(Symbol("b"))) val optimized = Optimize.execute(query.analyze) val correctAnswer = testRelation - .select('a.as('x), Year(CurrentDate()).as('y), 'b) - .groupBy('x, Year(CurrentDate()).as('y), 'b)(sum('x), avg(Year(CurrentDate())).as('AVG), - count('b)).analyze + .select(Symbol("a").as(Symbol("x")), Year(CurrentDate()).as(Symbol("y")), Symbol("b")) + .groupBy(Symbol("x"), Year(CurrentDate()).as(Symbol("y")), Symbol("b"))(sum(Symbol("x")), + avg(Year(CurrentDate())).as(Symbol("AVG")), + count(Symbol("b"))).analyze comparePlans(optimized, correctAnswer) } test("Propagate in a complex query") { val query = testRelation - .select('a.as('x), Year(CurrentDate()).as('y), 'b) - .where('x > 1 && 'y === 2016 && 'b > 1) - .groupBy('x, 'y, 'b)(sum('x), avg('y).as('AVG), count('b)) - .orderBy('x.asc, 'AVG.asc) + .select(Symbol("a").as(Symbol("x")), Year(CurrentDate()).as(Symbol("y")), Symbol("b")) + .where(Symbol("x") > 1 && Symbol("y") === 2016 && Symbol("b") > 1) + .groupBy(Symbol("x"), Symbol("y"), Symbol("b"))(sum(Symbol("x")), + avg(Symbol("y")).as(Symbol("AVG")), count(Symbol("b"))) + .orderBy(Symbol("x").asc, Symbol("AVG").asc) val optimized = Optimize.execute(query.analyze) val correctAnswer = testRelation - .select('a.as('x), Year(CurrentDate()).as('y), 'b) - .where('x > 1 && Year(CurrentDate()).as('y) === 2016 && 'b > 1) - 
.groupBy('x, Year(CurrentDate()).as("y"), 'b)(sum('x), avg(Year(CurrentDate())).as('AVG), - count('b)) - .orderBy('x.asc, 'AVG.asc).analyze + .select(Symbol("a").as(Symbol("x")), Year(CurrentDate()).as(Symbol("y")), Symbol("b")) + .where(Symbol("x") > 1 && Year(CurrentDate()).as(Symbol("y")) === 2016 && Symbol("b") > 1) + .groupBy(Symbol("x"), Year(CurrentDate()).as("y"), Symbol("b"))(sum(Symbol("x")), + avg(Year(CurrentDate())).as(Symbol("AVG")), + count(Symbol("b"))) + .orderBy(Symbol("x").asc, Symbol("AVG").asc).analyze comparePlans(optimized, correctAnswer) } @@ -116,27 +120,31 @@ class FoldablePropagationSuite extends PlanTest { test("Propagate in subqueries of Union queries") { val query = Union( Seq( - testRelation.select(Literal(1).as('x), 'a).select('x, 'x + 'a), - testRelation.select(Literal(2).as('x), 'a).select('x, 'x + 'a))) - .select('x) + testRelation.select(Literal(1).as(Symbol("x")), + Symbol("a")).select(Symbol("x"), Symbol("x") + Symbol("a")), + testRelation.select(Literal(2).as(Symbol("x")), + Symbol("a")).select(Symbol("x"), Symbol("x") + Symbol("a")))) + .select(Symbol("x")) val optimized = Optimize.execute(query.analyze) val correctAnswer = Union( Seq( - testRelation.select(Literal(1).as('x), 'a) - .select(Literal(1).as('x), (Literal(1).as('x) + 'a).as("(x + a)")), - testRelation.select(Literal(2).as('x), 'a) - .select(Literal(2).as('x), (Literal(2).as('x) + 'a).as("(x + a)")))) - .select('x).analyze + testRelation.select(Literal(1).as(Symbol("x")), Symbol("a")) + .select(Literal(1).as(Symbol("x")), + (Literal(1).as(Symbol("x")) + Symbol("a")).as("(x + a)")), + testRelation.select(Literal(2).as(Symbol("x")), Symbol("a")) + .select(Literal(2).as(Symbol("x")), + (Literal(2).as(Symbol("x")) + Symbol("a")).as("(x + a)")))) + .select(Symbol("x")).analyze comparePlans(optimized, correctAnswer) } test("Propagate in inner join") { - val ta = testRelation.select('a, Literal(1).as('tag)) - .union(testRelation.select('a.as('a), Literal(2).as('tag))) - .subquery('ta) - val tb = testRelation.select('a, Literal(1).as('tag)) - .union(testRelation.select('a.as('a), Literal(2).as('tag))) - .subquery('tb) + val ta = testRelation.select(Symbol("a"), Literal(1).as(Symbol("tag"))) + .union(testRelation.select(Symbol("a").as(Symbol("a")), Literal(2).as(Symbol("tag")))) + .subquery(Symbol("ta")) + val tb = testRelation.select(Symbol("a"), Literal(1).as(Symbol("tag"))) + .union(testRelation.select(Symbol("a").as(Symbol("a")), Literal(2).as(Symbol("tag")))) + .subquery(Symbol("tb")) val query = ta.join(tb, Inner, Some("ta.a".attr === "tb.a".attr && "ta.tag".attr === "tb.tag".attr)) val optimized = Optimize.execute(query.analyze) @@ -145,12 +153,12 @@ class FoldablePropagationSuite extends PlanTest { } test("Propagate in expand") { - val c1 = Literal(1).as('a) - val c2 = Literal(2).as('b) + val c1 = Literal(1).as(Symbol("a")) + val c2 = Literal(2).as(Symbol("b")) val a1 = c1.toAttribute.newInstance().withNullability(true) val a2 = c2.toAttribute.newInstance().withNullability(true) val expand = Expand( - Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))), + Seq(Seq(Literal(null), Symbol("b")), Seq(Symbol("a"), Literal(null))), Seq(a1, a2), OneRowRelation().select(c1, c2)) val query = expand.where(a1.isNotNull).select(a1, a2).analyze @@ -163,30 +171,32 @@ class FoldablePropagationSuite extends PlanTest { } test("Propagate above outer join") { - val left = LocalRelation('a.int).select('a, Literal(1).as('b)) - val right = LocalRelation('c.int).select('c, Literal(1).as('d)) + val left = 
LocalRelation(Symbol("a").int).select(Symbol("a"), Literal(1).as(Symbol("b"))) + val right = LocalRelation(Symbol("c").int).select(Symbol("c"), Literal(1).as(Symbol("d"))) val join = left.join( right, joinType = LeftOuter, - condition = Some('a === 'c && 'b === 'd)) - val query = join.select(('b + 3).as('res)).analyze + condition = Some(Symbol("a") === Symbol("c") && Symbol("b") === Symbol("d"))) + val query = join.select((Symbol("b") + 3).as(Symbol("res"))).analyze val optimized = Optimize.execute(query) val correctAnswer = left.join( right, joinType = LeftOuter, - condition = Some('a === 'c && Literal(1) === Literal(1))) - .select((Literal(1) + 3).as('res)).analyze + condition = Some(Symbol("a") === Symbol("c") && Literal(1) === Literal(1))) + .select((Literal(1) + 3).as(Symbol("res"))).analyze comparePlans(optimized, correctAnswer) } test("SPARK-32635: Replace references with foldables coming only from the node's children") { - val leftExpression = 'a.int - val left = LocalRelation(leftExpression).select('a) + val leftExpression = Symbol("a").int + val left = LocalRelation(leftExpression).select(Symbol("a")) val rightExpression = Alias(Literal(2), "a")(leftExpression.exprId) - val right = LocalRelation('b.int).select('b, rightExpression).select('b) - val join = left.join(right, joinType = LeftOuter, condition = Some('b === 'a)) + val right = LocalRelation(Symbol("b").int).select(Symbol("b"), + rightExpression).select(Symbol("b")) + val join = + left.join(right, joinType = LeftOuter, condition = Some(Symbol("b") === Symbol("a"))) val query = join.analyze val optimized = Optimize.execute(query) @@ -195,13 +205,15 @@ class FoldablePropagationSuite extends PlanTest { test("SPARK-32951: Foldable propagation from Aggregate") { val query = testRelation - .groupBy('a)('a, sum('b).as('b), Literal(1).as('c)) - .select('a, 'b, 'c) + .groupBy(Symbol("a"))(Symbol("a"), sum(Symbol("b")).as(Symbol("b")), + Literal(1).as(Symbol("c"))) + .select(Symbol("a"), Symbol("b"), Symbol("c")) val optimized = Optimize.execute(query.analyze) val correctAnswer = testRelation - .groupBy('a)('a, sum('b).as('b), Literal(1).as('c)) - .select('a, 'b, Literal(1).as('c)).analyze + .groupBy(Symbol("a"))(Symbol("a"), sum(Symbol("b")).as(Symbol("b")), + Literal(1).as(Symbol("c"))) + .select(Symbol("a"), Symbol("b"), Literal(1).as(Symbol("c"))).analyze comparePlans(optimized, correctAnswer) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 79bd573f1d84..8638786b7fc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -40,7 +40,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { PruneFilters) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) private def testConstraintsAfterJoin( x: LogicalPlan, @@ -56,46 +56,51 @@ class InferFiltersFromConstraintsSuite extends PlanTest { } test("filter: filter out constraints in condition") { - val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze + val originalQuery = + testRelation.where(Symbol("a") === 1 && Symbol("a") === Symbol("b")).analyze val correctAnswer = testRelation - 
.where(IsNotNull('a) && IsNotNull('b) && 'a === 'b && 'a === 1 && 'b === 1).analyze + .where(IsNotNull(Symbol("a")) && IsNotNull(Symbol("b")) && + Symbol("a") === Symbol("b") && Symbol("a") === 1 && Symbol("b") === 1).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } test("single inner join: filter out values on either side on equi-join keys") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = x.join(y, condition = Some(("x.a".attr === "y.a".attr) && ("x.a".attr === 1) && ("y.c".attr > 5))) .analyze - val left = x.where(IsNotNull('a) && "x.a".attr === 1) - val right = y.where(IsNotNull('a) && IsNotNull('c) && "y.c".attr > 5 && "y.a".attr === 1) + val left = x.where(IsNotNull(Symbol("a")) && "x.a".attr === 1) + val right = y.where(IsNotNull(Symbol("a")) && + IsNotNull(Symbol("c")) && "y.c".attr > 5 && "y.a".attr === 1) val correctAnswer = left.join(right, condition = Some("x.a".attr === "y.a".attr)).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } test("single inner join: filter out nulls on either side on non equal keys") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = x.join(y, condition = Some(("x.a".attr =!= "y.a".attr) && ("x.b".attr === 1) && ("y.c".attr > 5))) .analyze - val left = x.where(IsNotNull('a) && IsNotNull('b) && "x.b".attr === 1) - val right = y.where(IsNotNull('a) && IsNotNull('c) && "y.c".attr > 5) + val left = x.where(IsNotNull(Symbol("a")) && IsNotNull(Symbol("b")) && "x.b".attr === 1) + val right = y.where(IsNotNull(Symbol("a")) && IsNotNull(Symbol("c")) && "y.c".attr > 5) val correctAnswer = left.join(right, condition = Some("x.a".attr =!= "y.a".attr)).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } test("single inner join with pre-existing filters: filter out values on either side") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) - val originalQuery = x.where('b > 5).join(y.where('a === 10), + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) + val originalQuery = x.where(Symbol("b") > 5).join(y.where(Symbol("a") === 10), condition = Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).analyze - val left = x.where(IsNotNull('a) && 'a === 10 && IsNotNull('b) && 'b > 5) - val right = y.where(IsNotNull('a) && IsNotNull('b) && 'a === 10 && 'b > 5) + val left = x.where(IsNotNull(Symbol("a")) && + Symbol("a") === 10 && IsNotNull(Symbol("b")) && Symbol("b") > 5) + val right = y.where(IsNotNull(Symbol("a")) && + IsNotNull(Symbol("b")) && Symbol("a") === 10 && Symbol("b") > 5) val correctAnswer = left.join(right, condition = Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).analyze val optimized = Optimize.execute(originalQuery) @@ -103,8 +108,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest { } test("single outer join: no null filters are generated") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = x.join(y, FullOuter, condition = Some("x.a".attr === "y.a".attr)).analyze val optimized = Optimize.execute(originalQuery) @@ -112,47 +117,53 @@ 
class InferFiltersFromConstraintsSuite extends PlanTest { } test("multiple inner joins: filter out values on all sides on equi-join keys") { - val t1 = testRelation.subquery('t1) - val t2 = testRelation.subquery('t2) - val t3 = testRelation.subquery('t3) - val t4 = testRelation.subquery('t4) + val t1 = testRelation.subquery(Symbol("t1")) + val t2 = testRelation.subquery(Symbol("t2")) + val t3 = testRelation.subquery(Symbol("t3")) + val t4 = testRelation.subquery(Symbol("t4")) - val originalQuery = t1.where('b > 5) + val originalQuery = t1.where(Symbol("b") > 5) .join(t2, condition = Some("t1.b".attr === "t2.b".attr)) .join(t3, condition = Some("t2.b".attr === "t3.b".attr)) .join(t4, condition = Some("t3.b".attr === "t4.b".attr)).analyze - val correctAnswer = t1.where(IsNotNull('b) && 'b > 5) - .join(t2.where(IsNotNull('b) && 'b > 5), condition = Some("t1.b".attr === "t2.b".attr)) - .join(t3.where(IsNotNull('b) && 'b > 5), condition = Some("t2.b".attr === "t3.b".attr)) - .join(t4.where(IsNotNull('b) && 'b > 5), condition = Some("t3.b".attr === "t4.b".attr)) + val correctAnswer = t1.where(IsNotNull(Symbol("b")) && Symbol("b") > 5) + .join(t2.where(IsNotNull(Symbol("b")) && Symbol("b") > 5), + condition = Some("t1.b".attr === "t2.b".attr)) + .join(t3.where(IsNotNull(Symbol("b")) && Symbol("b") > 5), + condition = Some("t2.b".attr === "t3.b".attr)) + .join(t4.where(IsNotNull(Symbol("b")) && Symbol("b") > 5), + condition = Some("t3.b".attr === "t4.b".attr)) .analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } test("inner join with filter: filter out values on all sides on equi-join keys") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val originalQuery = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).where("x.a".attr > 5).analyze - val correctAnswer = x.where(IsNotNull('a) && 'a.attr > 5) - .join(y.where(IsNotNull('a) && 'a.attr > 5), Inner, Some("x.a".attr === "y.a".attr)).analyze + val correctAnswer = x.where(IsNotNull(Symbol("a")) && Symbol("a").attr > 5) + .join(y.where(IsNotNull(Symbol("a")) && Symbol("a").attr > 5), + Inner, Some("x.a".attr === "y.a".attr)).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } test("inner join with alias: alias contains multiple attributes") { - val t1 = testRelation.subquery('t1) - val t2 = testRelation.subquery('t2) + val t1 = testRelation.subquery(Symbol("t1")) + val t2 = testRelation.subquery(Symbol("t2")) - val originalQuery = t1.select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t") + val originalQuery = t1.select(Symbol("a"), Coalesce(Seq(Symbol("a"), + Symbol("b"))).as(Symbol("int_col"))).as("t") .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr)) .analyze val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))) - .select('a, Coalesce(Seq('a, 'b)).as('int_col)).as("t") - .join(t2.where(IsNotNull('a)), Inner, + .where(IsNotNull(Symbol("a")) && IsNotNull(Coalesce(Seq(Symbol("a"), Symbol("b")))) && + Symbol("a") === Coalesce(Seq(Symbol("a"), Symbol("b")))) + .select(Symbol("a"), Coalesce(Seq(Symbol("a"), Symbol("b"))).as(Symbol("int_col"))).as("t") + .join(t2.where(IsNotNull(Symbol("a"))), Inner, Some("t.a".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr)) .analyze val optimized = Optimize.execute(originalQuery) @@ -160,16 +171,16 @@ class 
InferFiltersFromConstraintsSuite extends PlanTest { } test("inner join with alias: alias contains single attributes") { - val t1 = testRelation.subquery('t1) - val t2 = testRelation.subquery('t2) + val t1 = testRelation.subquery(Symbol("t1")) + val t2 = testRelation.subquery(Symbol("t2")) - val originalQuery = t1.select('a, 'b.as('d)).as("t") + val originalQuery = t1.select(Symbol("a"), Symbol("b").as(Symbol("d"))).as("t") .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr)) .analyze val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull('b) &&'a === 'b) - .select('a, 'b.as('d)).as("t") - .join(t2.where(IsNotNull('a)), Inner, + .where(IsNotNull(Symbol("a")) && IsNotNull(Symbol("b")) &&Symbol("a") === Symbol("b")) + .select(Symbol("a"), Symbol("b").as(Symbol("d"))).as("t") + .join(t2.where(IsNotNull(Symbol("a"))), Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr)) .analyze val optimized = Optimize.execute(originalQuery) @@ -177,29 +188,34 @@ class InferFiltersFromConstraintsSuite extends PlanTest { } test("generate correct filters for alias that don't produce recursive constraints") { - val t1 = testRelation.subquery('t1) + val t1 = testRelation.subquery(Symbol("t1")) - val originalQuery = t1.select('a.as('x), 'b.as('y)).where('x === 1 && 'x === 'y).analyze + val originalQuery = t1.select(Symbol("a").as(Symbol("x")), + Symbol("b").as(Symbol("y"))).where(Symbol("x") === 1 && Symbol("x") === Symbol("y")).analyze val correctAnswer = - t1.where('a === 1 && 'b === 1 && 'a === 'b && IsNotNull('a) && IsNotNull('b)) - .select('a.as('x), 'b.as('y)).analyze + t1.where(Symbol("a") === 1 && Symbol("b") === 1 && Symbol("a") === Symbol("b") && + IsNotNull(Symbol("a")) && IsNotNull(Symbol("b"))) + .select(Symbol("a").as(Symbol("x")), Symbol("b").as(Symbol("y"))).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } test("No inferred filter when constraint propagation is disabled") { withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { - val originalQuery = testRelation.where('a === 1 && 'a === 'b).analyze + val originalQuery = + testRelation.where(Symbol("a") === 1 && Symbol("a") === Symbol("b")).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, originalQuery) } } test("constraints should be inferred from aliased literals") { - val originalLeft = testRelation.subquery('left).as("left") - val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a <=> 2).as("left") + val originalLeft = testRelation.subquery(Symbol("left")).as("left") + val optimizedLeft = testRelation.subquery( + Symbol("left")).where(IsNotNull(Symbol("a")) &&Symbol("a") <=> 2).as("left") - val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right") + val right = Project( + Seq(Literal(2).as("two")), testRelation.subquery(Symbol("right"))).as("right") val condition = Some("left.a".attr === "right.two".attr) val original = originalLeft.join(right, Inner, condition) @@ -209,70 +225,73 @@ class InferFiltersFromConstraintsSuite extends PlanTest { } test("SPARK-23405: left-semi equal-join should filter out null join keys on both sides") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) - testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y.where(IsNotNull('a)), LeftSemi) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) + testConstraintsAfterJoin( + x, y, x.where(IsNotNull(Symbol("a"))), 
y.where(IsNotNull(Symbol("a"))), LeftSemi) } test("SPARK-21479: Outer join after-join filters push down to null-supplying side") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val condition = Some("x.a".attr === "y.a".attr) val originalQuery = x.join(y, LeftOuter, condition).where("x.a".attr === 2).analyze - val left = x.where(IsNotNull('a) && 'a === 2) - val right = y.where(IsNotNull('a) && 'a === 2) + val left = x.where(IsNotNull(Symbol("a")) && Symbol("a") === 2) + val right = y.where(IsNotNull(Symbol("a")) && Symbol("a") === 2) val correctAnswer = left.join(right, LeftOuter, condition).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } test("SPARK-21479: Outer join pre-existing filters push down to null-supplying side") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val condition = Some("x.a".attr === "y.a".attr) val originalQuery = x.join(y.where("y.a".attr > 5), RightOuter, condition).analyze - val left = x.where(IsNotNull('a) && 'a > 5) - val right = y.where(IsNotNull('a) && 'a > 5) + val left = x.where(IsNotNull(Symbol("a")) && Symbol("a") > 5) + val right = y.where(IsNotNull(Symbol("a")) && Symbol("a") > 5) val correctAnswer = left.join(right, RightOuter, condition).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) } test("SPARK-21479: Outer join no filter push down to preserved side") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) testConstraintsAfterJoin( x, y.where("a".attr === 1), - x, y.where(IsNotNull('a) && 'a === 1), + x, y.where(IsNotNull(Symbol("a")) && Symbol("a") === 1), LeftOuter) } test("SPARK-23564: left anti join should filter out null join keys on right side") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) - testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftAnti) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull(Symbol("a"))), LeftAnti) } test("SPARK-23564: left outer join should filter out null join keys on right side") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) - testConstraintsAfterJoin(x, y, x, y.where(IsNotNull('a)), LeftOuter) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) + testConstraintsAfterJoin(x, y, x, y.where(IsNotNull(Symbol("a"))), LeftOuter) } test("SPARK-23564: right outer join should filter out null join keys on left side") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) - testConstraintsAfterJoin(x, y, x.where(IsNotNull('a)), y, RightOuter) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) + testConstraintsAfterJoin(x, y, x.where(IsNotNull(Symbol("a"))), y, RightOuter) } test("Constraints should be inferred from cast equality constraint(filter higher data type)") { - val testRelation1 = LocalRelation('a.int) - val testRelation2 = LocalRelation('b.long) - val originalLeft = testRelation1.subquery('left) - val originalRight = testRelation2.where('b === 1L).subquery('right) + val testRelation1 = LocalRelation(Symbol("a").int) + val testRelation2 = 
LocalRelation(Symbol("b").long) + val originalLeft = testRelation1.subquery(Symbol("left")) + val originalRight = testRelation2.where(Symbol("b") === 1L).subquery(Symbol("right")) - val left = testRelation1.where(IsNotNull('a) && 'a.cast(LongType) === 1L).subquery('left) - val right = testRelation2.where(IsNotNull('b) && 'b === 1L).subquery('right) + val left = testRelation1.where(IsNotNull(Symbol("a")) && + Symbol("a").cast(LongType) === 1L).subquery(Symbol("left")) + val right = testRelation2.where(IsNotNull(Symbol("b")) && + Symbol("b") === 1L).subquery(Symbol("right")) Seq(Some("left.a".attr.cast(LongType) === "right.b".attr), Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition => @@ -284,7 +303,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest { testConstraintsAfterJoin( originalLeft, originalRight, - testRelation1.where(IsNotNull('a)).subquery('left), + testRelation1.where(IsNotNull(Symbol("a"))).subquery(Symbol("left")), right, Inner, condition) @@ -292,13 +311,14 @@ class InferFiltersFromConstraintsSuite extends PlanTest { } test("Constraints shouldn't be inferred from cast equality constraint(filter lower data type)") { - val testRelation1 = LocalRelation('a.int) - val testRelation2 = LocalRelation('b.long) - val originalLeft = testRelation1.where('a === 1).subquery('left) - val originalRight = testRelation2.subquery('right) + val testRelation1 = LocalRelation(Symbol("a").int) + val testRelation2 = LocalRelation(Symbol("b").long) + val originalLeft = testRelation1.where(Symbol("a") === 1).subquery(Symbol("left")) + val originalRight = testRelation2.subquery(Symbol("right")) - val left = testRelation1.where(IsNotNull('a) && 'a === 1).subquery('left) - val right = testRelation2.where(IsNotNull('b)).subquery('right) + val left = testRelation1.where(IsNotNull(Symbol("a")) && + Symbol("a") === 1).subquery(Symbol("left")) + val right = testRelation2.where(IsNotNull(Symbol("b"))).subquery(Symbol("right")) Seq(Some("left.a".attr.cast(LongType) === "right.b".attr), Some("right.b".attr === "left.a".attr.cast(LongType))).foreach { condition => @@ -311,7 +331,8 @@ class InferFiltersFromConstraintsSuite extends PlanTest { originalLeft, originalRight, left, - testRelation2.where(IsNotNull('b) && 'b.attr.cast(IntegerType) === 1).subquery('right), + testRelation2.where(IsNotNull(Symbol("b")) && + Symbol("b").attr.cast(IntegerType) === 1).subquery(Symbol("right")), Inner, condition) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala index 93a1d414ed40..75f226f8ffed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala @@ -31,17 +31,17 @@ class InferFiltersFromGenerateSuite extends PlanTest { val batches = Batch("Infer Filters", Once, InferFiltersFromGenerate) :: Nil } - val testRelation = LocalRelation('a.array(StructType(Seq( + val testRelation = LocalRelation(Symbol("a").array(StructType(Seq( StructField("x", IntegerType), StructField("y", IntegerType) - ))), 'c1.string, 'c2.string) + ))), Symbol("c1").string, Symbol("c2").string) Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f => - val generator = f('a) + val generator = f(Symbol("a")) test("Infer filters from " + generator) { val originalQuery = 
testRelation.generate(generator).analyze val correctAnswer = testRelation - .where(IsNotNull('a) && Size('a) > 0) + .where(IsNotNull(Symbol("a")) && Size(Symbol("a")) > 0) .generate(generator) .analyze val optimized = Optimize.execute(originalQuery) @@ -50,7 +50,7 @@ class InferFiltersFromGenerateSuite extends PlanTest { test("Don't infer duplicate filters from " + generator) { val originalQuery = testRelation - .where(IsNotNull('a) && Size('a) > 0) + .where(IsNotNull(Symbol("a")) && Size(Symbol("a")) > 0) .generate(generator) .analyze val optimized = Optimize.execute(originalQuery) @@ -89,13 +89,13 @@ class InferFiltersFromGenerateSuite extends PlanTest { } Seq(Explode(_), PosExplode(_)).foreach { f => - val createArrayExplode = f(CreateArray(Seq('c1))) + val createArrayExplode = f(CreateArray(Seq(Symbol("c1")))) test("SPARK-33544: Don't infer filters from CreateArray " + createArrayExplode) { val originalQuery = testRelation.generate(createArrayExplode).analyze val optimized = OptimizeInferAndConstantFold.execute(originalQuery) comparePlans(optimized, originalQuery) } - val createMapExplode = f(CreateMap(Seq('c1, 'c2))) + val createMapExplode = f(CreateMap(Seq(Symbol("c1"), Symbol("c2")))) test("SPARK-33544: Don't infer filters from CreateMap " + createMapExplode) { val originalQuery = testRelation.generate(createMapExplode).analyze val optimized = OptimizeInferAndConstantFold.execute(originalQuery) @@ -104,7 +104,7 @@ class InferFiltersFromGenerateSuite extends PlanTest { } Seq(Inline(_)).foreach { f => - val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq('c1))))) + val createArrayStructExplode = f(CreateArray(Seq(CreateStruct(Seq(Symbol("c1")))))) test("SPARK-33544: Don't infer filters from CreateArray " + createArrayStructExplode) { val originalQuery = testRelation.generate(createArrayStructExplode).analyze val optimized = OptimizeInferAndConstantFold.execute(originalQuery) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 3d81c567eff1..67336a8ca315 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -44,13 +44,13 @@ class JoinOptimizationSuite extends PlanTest { } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation1 = LocalRelation('d.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val testRelation1 = LocalRelation(Symbol("d").int) test("extract filters and joins") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) - val z = testRelation.subquery('z) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) + val z = testRelation.subquery(Symbol("z")) def testExtract(plan: LogicalPlan, expected: Option[(Seq[LogicalPlan], Seq[Expression])]): Unit = { @@ -96,9 +96,9 @@ class JoinOptimizationSuite extends PlanTest { } test("reorder inner joins") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) - val z = testRelation.subquery('z) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) + val z = testRelation.subquery(Symbol("z")) val queryAnswers = Seq( ( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala index 3513cfa14808..0ada38e139b4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinSelectionHelperSuite.scala @@ -27,13 +27,13 @@ import org.apache.spark.sql.internal.SQLConf class JoinSelectionHelperSuite extends PlanTest with JoinSelectionHelper { private val left = StatsTestPlan( - outputList = Seq('a.int, 'b.int, 'c.int), + outputList = Seq(Symbol("a").int, Symbol("b").int, Symbol("c").int), rowCount = 20000000, size = Some(20000000), attributeStats = AttributeMap(Seq())) private val right = StatsTestPlan( - outputList = Seq('d.int), + outputList = Seq(Symbol("d").int), rowCount = 1000, size = Some(1000), attributeStats = AttributeMap(Seq())) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala index 1672c6d91660..07d13470deb7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LeftSemiAntiJoinPushDownSuite.scala @@ -42,41 +42,41 @@ class LeftSemiPushdownSuite extends PlanTest { CollapseProject) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation1 = LocalRelation('d.int) - val testRelation2 = LocalRelation('e.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val testRelation1 = LocalRelation(Symbol("d").int) + val testRelation2 = LocalRelation(Symbol("e").int) test("Project: LeftSemiAnti join pushdown") { val originalQuery = testRelation .select(star()) - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("b") === Symbol("d"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) - .select('a, 'b, 'c) + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("b") === Symbol("d"))) + .select(Symbol("a"), Symbol("b"), Symbol("c")) .analyze comparePlans(optimized, correctAnswer) } test("Project: LeftSemiAnti join no pushdown because of non-deterministic proj exprs") { val originalQuery = testRelation - .select(Rand(1), 'b, 'c) - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + .select(Rand(1), Symbol("b"), Symbol("c")) + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("b") === Symbol("d"))) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) } test("Project: LeftSemiAnti join non correlated scalar subq") { - val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze) + val subq = ScalarSubquery(testRelation.groupBy(Symbol("b"))(sum(Symbol("c")).as("sum")).analyze) val originalQuery = testRelation .select(subq.as("sum")) - .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd)) + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("sum") === Symbol("d"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .join(testRelation1, joinType = LeftSemi, condition = Some(subq === 'd)) + 
.join(testRelation1, joinType = LeftSemi, condition = Some(subq === Symbol("d"))) .select(subq.as("sum")) .analyze @@ -84,12 +84,13 @@ class LeftSemiPushdownSuite extends PlanTest { } test("Project: LeftSemiAnti join no pushdown - correlated scalar subq in projection list") { - val testRelation2 = LocalRelation('e.int, 'f.int) - val subqPlan = testRelation2.groupBy('e)(sum('f).as("sum")).where('e === 'a) + val testRelation2 = LocalRelation(Symbol("e").int, Symbol("f").int) + val subqPlan = testRelation2.groupBy(Symbol("e"))(sum(Symbol("f")).as("sum")) + .where(Symbol("e") === Symbol("a")) val subqExpr = ScalarSubquery(subqPlan) val originalQuery = testRelation .select(subqExpr.as("sum")) - .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd)) + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("sum") === Symbol("d"))) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) @@ -97,13 +98,13 @@ class LeftSemiPushdownSuite extends PlanTest { test("Aggregate: LeftSemiAnti join pushdown") { val originalQuery = testRelation - .groupBy('b)('b, sum('c)) - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + .groupBy(Symbol("b"))(Symbol("b"), sum(Symbol("c"))) + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("b") === Symbol("d"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) - .groupBy('b)('b, sum('c)) + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("b") === Symbol("d"))) + .groupBy(Symbol("b"))(Symbol("b"), sum(Symbol("c"))) .analyze comparePlans(optimized, correctAnswer) @@ -111,8 +112,8 @@ class LeftSemiPushdownSuite extends PlanTest { test("Aggregate: LeftSemiAnti join no pushdown due to non-deterministic aggr expressions") { val originalQuery = testRelation - .groupBy('b)('b, Rand(10).as('c)) - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + .groupBy(Symbol("b"))(Symbol("b"), Rand(10).as(Symbol("c"))) + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("b") === Symbol("d"))) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) @@ -120,14 +121,15 @@ class LeftSemiPushdownSuite extends PlanTest { test("Aggregate: LeftSemi join partial pushdown") { val originalQuery = testRelation - .groupBy('b)('b, sum('c).as('sum)) - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 10)) + .groupBy(Symbol("b"))(Symbol("b"), sum(Symbol("c")).as(Symbol("sum"))) + .join(testRelation1, joinType = LeftSemi, + condition = Some(Symbol("b") === Symbol("d") && Symbol("sum") === 10)) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) - .groupBy('b)('b, sum('c).as('sum)) - .where('sum === 10) + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("b") === Symbol("d"))) + .groupBy(Symbol("b"))(Symbol("b"), sum(Symbol("c")).as(Symbol("sum"))) + .where(Symbol("sum") === 10) .analyze comparePlans(optimized, correctAnswer) @@ -135,8 +137,9 @@ class LeftSemiPushdownSuite extends PlanTest { test("Aggregate: LeftAnti join no pushdown") { val originalQuery = testRelation - .groupBy('b)('b, sum('c).as('sum)) - .join(testRelation1, joinType = LeftAnti, condition = Some('b === 'd && 'sum === 10)) + .groupBy(Symbol("b"))(Symbol("b"), 
sum(Symbol("c")).as(Symbol("sum"))) + .join(testRelation1, joinType = LeftAnti, + condition = Some(Symbol("b") === Symbol("d") && Symbol("sum") === 10)) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) @@ -144,42 +147,46 @@ class LeftSemiPushdownSuite extends PlanTest { test("LeftSemiAnti join over aggregate - no pushdown") { val originalQuery = testRelation - .groupBy('b)('b, sum('c).as('sum)) - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'sum === 'd)) + .groupBy(Symbol("b"))(Symbol("b"), sum(Symbol("c")).as(Symbol("sum"))) + .join(testRelation1, joinType = LeftSemi, + condition = Some(Symbol("b") === Symbol("d") && Symbol("sum") === Symbol("d"))) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) } test("Aggregate: LeftSemiAnti join non-correlated scalar subq aggr exprs") { - val subq = ScalarSubquery(testRelation.groupBy('b)(sum('c).as("sum")).analyze) + val subq = ScalarSubquery(testRelation.groupBy(Symbol("b"))(sum(Symbol("c")).as("sum")).analyze) val originalQuery = testRelation - .groupBy('a) ('a, subq.as("sum")) - .join(testRelation1, joinType = LeftSemi, condition = Some('sum === 'd && 'a === 'd)) + .groupBy(Symbol("a")) (Symbol("a"), subq.as("sum")) + .join(testRelation1, joinType = LeftSemi, + condition = Some(Symbol("sum") === Symbol("d") && Symbol("a") === Symbol("d"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .join(testRelation1, joinType = LeftSemi, condition = Some(subq === 'd && 'a === 'd)) - .groupBy('a) ('a, subq.as("sum")) + .join(testRelation1, joinType = LeftSemi, + condition = Some(subq === Symbol("d") && Symbol("a") === Symbol("d"))) + .groupBy(Symbol("a")) (Symbol("a"), subq.as("sum")) .analyze comparePlans(optimized, correctAnswer) } test("LeftSemiAnti join over Window") { - val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) + val winExpr = windowExpr(count(Symbol("b")), + windowSpec(Symbol("a") :: Nil, Symbol("b").asc :: Nil, UnspecifiedFrame)) val originalQuery = testRelation - .select('a, 'b, 'c, winExpr.as('window)) - .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd)) + .select(Symbol("a"), Symbol("b"), Symbol("c"), winExpr.as(Symbol("window"))) + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("a") === Symbol("d"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd)) - .select('a, 'b, 'c) - .window(winExpr.as('window) :: Nil, 'a :: Nil, 'b.asc :: Nil) - .select('a, 'b, 'c, 'window).analyze + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("a") === Symbol("d"))) + .select(Symbol("a"), Symbol("b"), Symbol("c")) + .window(winExpr.as(Symbol("window")) :: Nil, Symbol("a") :: Nil, Symbol("b").asc :: Nil) + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("window")).analyze comparePlans(optimized, correctAnswer) } @@ -187,20 +194,22 @@ class LeftSemiPushdownSuite extends PlanTest { test("Window: LeftSemi partial pushdown") { // Attributes from join condition which does not refer to the window partition spec // are kept up in the plan as a Filter operator above Window. 
- val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) + val winExpr = windowExpr(count(Symbol("b")), + windowSpec(Symbol("a") :: Nil, Symbol("b").asc :: Nil, UnspecifiedFrame)) val originalQuery = testRelation - .select('a, 'b, 'c, winExpr.as('window)) - .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd && 'b > 5)) + .select(Symbol("a"), Symbol("b"), Symbol("c"), winExpr.as(Symbol("window"))) + .join(testRelation1, joinType = LeftSemi, + condition = Some(Symbol("a") === Symbol("d") && Symbol("b") > 5)) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd)) - .select('a, 'b, 'c) - .window(winExpr.as('window) :: Nil, 'a :: Nil, 'b.asc :: Nil) - .where('b > 5) - .select('a, 'b, 'c, 'window).analyze + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("a") === Symbol("d"))) + .select(Symbol("a"), Symbol("b"), Symbol("c")) + .window(winExpr.as(Symbol("window")) :: Nil, Symbol("a") :: Nil, Symbol("b").asc :: Nil) + .where(Symbol("b") > 5) + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("window")).analyze comparePlans(optimized, correctAnswer) } @@ -208,43 +217,48 @@ class LeftSemiPushdownSuite extends PlanTest { test("Window: LeftAnti no pushdown") { // Attributes from join condition which does not refer to the window partition spec // are kept up in the plan as a Filter operator above Window. - val winExpr = windowExpr(count('b), windowSpec('a :: Nil, 'b.asc :: Nil, UnspecifiedFrame)) + val winExpr = windowExpr(count(Symbol("b")), + windowSpec(Symbol("a") :: Nil, Symbol("b").asc :: Nil, UnspecifiedFrame)) val originalQuery = testRelation - .select('a, 'b, 'c, winExpr.as('window)) - .join(testRelation1, joinType = LeftAnti, condition = Some('a === 'd && 'b > 5)) + .select(Symbol("a"), Symbol("b"), Symbol("c"), winExpr.as(Symbol("window"))) + .join(testRelation1, joinType = LeftAnti, + condition = Some(Symbol("a") === Symbol("d") && Symbol("b") > 5)) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b, 'c) - .window(winExpr.as('window) :: Nil, 'a :: Nil, 'b.asc :: Nil) - .join(testRelation1, joinType = LeftAnti, condition = Some('a === 'd && 'b > 5)) - .select('a, 'b, 'c, 'window).analyze + .select(Symbol("a"), Symbol("b"), Symbol("c")) + .window(winExpr.as(Symbol("window")) :: Nil, Symbol("a") :: Nil, Symbol("b").asc :: Nil) + .join(testRelation1, joinType = LeftAnti, + condition = Some(Symbol("a") === Symbol("d") && Symbol("b") > 5)) + .select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("window")).analyze comparePlans(optimized, correctAnswer) } test("Union: LeftSemiAnti join pushdown") { - val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int) + val testRelation2 = LocalRelation(Symbol("x").int, Symbol("y").int, Symbol("z").int) val originalQuery = Union(Seq(testRelation, testRelation2)) - .join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd)) + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("a") === Symbol("d"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = Union(Seq( - testRelation.join(testRelation1, joinType = LeftSemi, condition = Some('a === 'd)), - testRelation2.join(testRelation1, joinType = LeftSemi, condition = Some('x === 'd)))) + testRelation.join(testRelation1, joinType = LeftSemi, + condition = Some(Symbol("a") === Symbol("d"))), + testRelation2.join(testRelation1, 
joinType = LeftSemi, + condition = Some(Symbol("x") === Symbol("d"))))) .analyze comparePlans(optimized, correctAnswer) } test("Union: LeftSemiAnti join no pushdown in self join scenario") { - val testRelation2 = LocalRelation('x.int, 'y.int, 'z.int) + val testRelation2 = LocalRelation(Symbol("x").int, Symbol("y").int, Symbol("z").int) val originalQuery = Union(Seq(testRelation, testRelation2)) - .join(testRelation2, joinType = LeftSemi, condition = Some('a === 'x)) + .join(testRelation2, joinType = LeftSemi, condition = Some(Symbol("a") === Symbol("x"))) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) @@ -254,12 +268,12 @@ class LeftSemiPushdownSuite extends PlanTest { val originalQuery = testRelation .select(star()) .repartition(1) - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("b") === Symbol("d"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) - .select('a, 'b, 'c) + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("b") === Symbol("d"))) + .select(Symbol("a"), Symbol("b"), Symbol("c")) .repartition(1) .analyze comparePlans(optimized, correctAnswer) @@ -274,64 +288,72 @@ class LeftSemiPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .join(testRelation1, joinType = LeftSemi, condition = None) - .select('a, 'b, 'c) + .select(Symbol("a"), Symbol("b"), Symbol("c")) .repartition(1) .analyze comparePlans(optimized, correctAnswer) } test("Unary: LeftSemi join pushdown - partial pushdown") { - val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) + val testRelationWithArrayType = LocalRelation(Symbol("a").int, + Symbol("b").int, Symbol("c_arr").array(IntegerType)) val originalQuery = testRelationWithArrayType - .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col")) - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'b === 'out_col)) + .generate(Explode(Symbol("c_arr")), alias = Some("arr"), outputNames = Seq("out_col")) + .join(testRelation1, joinType = LeftSemi, + condition = Some(Symbol("b") === Symbol("d") && Symbol("b") === Symbol("out_col"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelationWithArrayType - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd)) - .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col")) - .where('b === 'out_col) + .join(testRelation1, joinType = LeftSemi, condition = Some(Symbol("b") === Symbol("d"))) + .generate(Explode(Symbol("c_arr")), alias = Some("arr"), outputNames = Seq("out_col")) + .where(Symbol("b") === Symbol("out_col")) .analyze comparePlans(optimized, correctAnswer) } test("Unary: LeftAnti join pushdown - no pushdown") { - val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) + val testRelationWithArrayType = LocalRelation(Symbol("a").int, + Symbol("b").int, Symbol("c_arr").array(IntegerType)) val originalQuery = testRelationWithArrayType - .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col")) - .join(testRelation1, joinType = LeftAnti, condition = Some('b === 'd && 'b === 'out_col)) + .generate(Explode(Symbol("c_arr")), alias = Some("arr"), outputNames = Seq("out_col")) + 
.join(testRelation1, joinType = LeftAnti, + condition = Some(Symbol("b") === Symbol("d") && Symbol("b") === Symbol("out_col"))) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) } test("Unary: LeftSemiAnti join pushdown - no pushdown") { - val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType)) + val testRelationWithArrayType = LocalRelation(Symbol("a").int, + Symbol("b").int, Symbol("c_arr").array(IntegerType)) val originalQuery = testRelationWithArrayType - .generate(Explode('c_arr), alias = Some("arr"), outputNames = Seq("out_col")) - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'd === 'out_col)) + .generate(Explode(Symbol("c_arr")), alias = Some("arr"), outputNames = Seq("out_col")) + .join(testRelation1, joinType = LeftSemi, + condition = Some(Symbol("b") === Symbol("d") && Symbol("d") === Symbol("out_col"))) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) } test("Unary: LeftSemi join push down through Expand") { - val expand = Expand(Seq(Seq('a, 'b, "null"), Seq('a, "null", 'c)), - Seq('a, 'b, 'c), testRelation) - val originalQuery = expand - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'b === 1)) + val expand = Expand(Seq(Seq(Symbol("a"), Symbol("b"), "null"), Seq(Symbol("a"), "null", + Symbol("c"))), Seq(Symbol("a"), Symbol("b"), Symbol("c")), testRelation) + val originalQuery = expand.join(testRelation1, joinType = LeftSemi, + condition = Some(Symbol("b") === Symbol("d") && Symbol("b") === 1)) val optimized = Optimize.execute(originalQuery.analyze) - val correctAnswer = Expand(Seq(Seq('a, 'b, "null"), Seq('a, "null", 'c)), - Seq('a, 'b, 'c), testRelation - .join(testRelation1, joinType = LeftSemi, condition = Some('b === 'd && 'b === 1))) + val correctAnswer = Expand(Seq(Seq(Symbol("a"), + Symbol("b"), "null"), Seq(Symbol("a"), "null", Symbol("c"))), + Seq(Symbol("a"), Symbol("b"), Symbol("c")), testRelation + .join(testRelation1, joinType = LeftSemi, + condition = Some(Symbol("b") === Symbol("d") && Symbol("b") === 1))) .analyze comparePlans(optimized, correctAnswer) } - Seq(Some('d === 'e), None).foreach { case innerJoinCond => + Seq(Some(Symbol("d") === Symbol("e")), None).foreach { case innerJoinCond => Seq(LeftSemi, LeftAnti).foreach { case outerJT => Seq(Inner, LeftOuter, Cross, RightOuter).foreach { case innerJT => test(s"$outerJT pushdown empty join cond join type $innerJT join cond $innerJoinCond") { @@ -352,17 +374,19 @@ class LeftSemiPushdownSuite extends PlanTest { } } - Seq(Some('d === 'e), None).foreach { case innerJoinCond => + Seq(Some(Symbol("d") === Symbol("e")), None).foreach { case innerJoinCond => Seq(LeftSemi, LeftAnti).foreach { case outerJT => Seq(Inner, LeftOuter, Cross).foreach { case innerJT => test(s"$outerJT pushdown to left of join type: $innerJT join condition $innerJoinCond") { val joinedRelation = testRelation1.join(testRelation2, joinType = innerJT, innerJoinCond) val originalQuery = - joinedRelation.join(testRelation, joinType = outerJT, condition = Some('a === 'd)) + joinedRelation.join(testRelation, joinType = outerJT, + condition = Some(Symbol("a") === Symbol("d"))) val optimized = Optimize.execute(originalQuery.analyze) val pushedDownJoin = - testRelation1.join(testRelation, joinType = outerJT, condition = Some('a === 'd)) + testRelation1.join(testRelation, joinType = outerJT, + condition = Some(Symbol("a") === Symbol("d"))) val correctAnswer = 
pushedDownJoin.join(testRelation2, joinType = innerJT, innerJoinCond).analyze comparePlans(optimized, correctAnswer) @@ -371,17 +395,19 @@ class LeftSemiPushdownSuite extends PlanTest { } } - Seq(Some('e === 'd), None).foreach { case innerJoinCond => + Seq(Some(Symbol("e") === Symbol("d")), None).foreach { case innerJoinCond => Seq(LeftSemi, LeftAnti).foreach { case outerJT => Seq(Inner, RightOuter, Cross).foreach { case innerJT => test(s"$outerJT pushdown to right of join type: $innerJT join condition $innerJoinCond") { val joinedRelation = testRelation1.join(testRelation2, joinType = innerJT, innerJoinCond) val originalQuery = - joinedRelation.join(testRelation, joinType = outerJT, condition = Some('a === 'e)) + joinedRelation.join(testRelation, joinType = outerJT, + condition = Some(Symbol("a") === Symbol("e"))) val optimized = Optimize.execute(originalQuery.analyze) val pushedDownJoin = - testRelation2.join(testRelation, joinType = outerJT, condition = Some('a === 'e)) + testRelation2.join(testRelation, joinType = outerJT, + condition = Some(Symbol("a") === Symbol("e"))) val correctAnswer = testRelation1.join(pushedDownJoin, joinType = innerJT, innerJoinCond).analyze comparePlans(optimized, correctAnswer) @@ -394,7 +420,8 @@ class LeftSemiPushdownSuite extends PlanTest { test(s"$jt no pushdown - join condition refers left leg - join type for RightOuter") { val joinedRelation = testRelation1.join(testRelation2, joinType = RightOuter, None) val originalQuery = - joinedRelation.join(testRelation, joinType = jt, condition = Some('a === 'd)) + joinedRelation.join(testRelation, joinType = jt, + condition = Some(Symbol("a") === Symbol("d"))) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) } @@ -404,7 +431,8 @@ class LeftSemiPushdownSuite extends PlanTest { test(s"$jt no pushdown - join condition refers right leg - join type for LeftOuter") { val joinedRelation = testRelation1.join(testRelation2, joinType = LeftOuter, None) val originalQuery = - joinedRelation.join(testRelation, joinType = jt, condition = Some('a === 'e)) + joinedRelation.join(testRelation, joinType = jt, + condition = Some(Symbol("a") === Symbol("e"))) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) } @@ -415,7 +443,8 @@ class LeftSemiPushdownSuite extends PlanTest { test(s"$outerJT no pushdown - join condition refers both leg - join type $innerJT") { val joinedRelation = testRelation1.join(testRelation2, joinType = innerJT, None) val originalQuery = joinedRelation - .join(testRelation, joinType = outerJT, condition = Some('a === 'd && 'a === 'e)) + .join(testRelation, joinType = outerJT, + condition = Some(Symbol("a") === Symbol("d") && Symbol("a") === Symbol("e"))) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) } @@ -427,7 +456,8 @@ class LeftSemiPushdownSuite extends PlanTest { test(s"$outerJT no pushdown - join condition refers none of the leg - join type $innerJT") { val joinedRelation = testRelation1.join(testRelation2, joinType = innerJT, None) val originalQuery = joinedRelation - .join(testRelation, joinType = outerJT, condition = Some('d + 'e === 'a)) + .join(testRelation, joinType = outerJT, + condition = Some(Symbol("d") + Symbol("e") === Symbol("a"))) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) } @@ -438,7 +468,8 @@ class LeftSemiPushdownSuite extends PlanTest { test(s"$jt no pushdown when 
child join type is FullOuter") { val joinedRelation = testRelation1.join(testRelation2, joinType = FullOuter, None) val originalQuery = - joinedRelation.join(testRelation, joinType = jt, condition = Some('a === 'e)) + joinedRelation.join(testRelation, joinType = jt, + condition = Some(Symbol("a") === Symbol("e"))) val optimized = Optimize.execute(originalQuery.analyze) comparePlans(optimized, originalQuery.analyze) } @@ -449,14 +480,14 @@ class LeftSemiPushdownSuite extends PlanTest { Seq(-1, 100000).foreach { threshold => withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> threshold.toString) { val originalQuery = testRelation - .groupBy('b)('b) - .join(testRelation1, joinType = jt, condition = Some('b <=> 'd)) + .groupBy(Symbol("b"))(Symbol("b")) + .join(testRelation1, joinType = jt, condition = Some(Symbol("b") <=> Symbol("d"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = if (threshold > 0) { testRelation - .join(testRelation1, joinType = jt, condition = Some('b <=> 'd)) - .groupBy('b)('b) + .join(testRelation1, joinType = jt, condition = Some(Symbol("b") <=> Symbol("d"))) + .groupBy(Symbol("b"))(Symbol("b")) .analyze } else { originalQuery.analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index c06c92f9c151..71b9d5a178e5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -33,16 +33,16 @@ class LikeSimplificationSuite extends PlanTest { LikeSimplification) :: Nil } - val testRelation = LocalRelation('a.string) + val testRelation = LocalRelation(Symbol("a").string) test("simplify Like into StartsWith") { val originalQuery = testRelation - .where(('a like "abc%") || ('a like "abc\\%")) + .where((Symbol("a") like "abc%") || (Symbol("a") like "abc\\%")) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(StartsWith('a, "abc") || ('a like "abc\\%")) + .where(StartsWith(Symbol("a"), "abc") || (Symbol("a") like "abc\\%")) .analyze comparePlans(optimized, correctAnswer) @@ -51,11 +51,11 @@ class LikeSimplificationSuite extends PlanTest { test("simplify Like into EndsWith") { val originalQuery = testRelation - .where('a like "%xyz") + .where(Symbol("a") like "%xyz") val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(EndsWith('a, "xyz")) + .where(EndsWith(Symbol("a"), "xyz")) .analyze comparePlans(optimized, correctAnswer) @@ -64,12 +64,12 @@ class LikeSimplificationSuite extends PlanTest { test("simplify Like into startsWith and EndsWith") { val originalQuery = testRelation - .where(('a like "abc\\%def") || ('a like "abc%def")) + .where((Symbol("a") like "abc\\%def") || (Symbol("a") like "abc%def")) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(('a like "abc\\%def") || - (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) + .where((Symbol("a") like "abc\\%def") || (Length(Symbol("a")) >= 6 && + (StartsWith(Symbol("a"), "abc") && EndsWith(Symbol("a"), "def")))) .analyze comparePlans(optimized, correctAnswer) @@ -78,11 +78,11 @@ class LikeSimplificationSuite extends PlanTest { test("simplify Like into Contains") { val originalQuery = testRelation - .where(('a like "%mn%") 
|| ('a like "%mn\\%")) + .where((Symbol("a") like "%mn%") || (Symbol("a") like "%mn\\%")) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(Contains('a, "mn") || ('a like "%mn\\%")) + .where(Contains(Symbol("a"), "mn") || (Symbol("a") like "%mn\\%")) .analyze comparePlans(optimized, correctAnswer) @@ -91,28 +91,28 @@ class LikeSimplificationSuite extends PlanTest { test("simplify Like into EqualTo") { val originalQuery = testRelation - .where(('a like "") || ('a like "abc")) + .where((Symbol("a") like "") || (Symbol("a") like "abc")) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where(('a === "") || ('a === "abc")) + .where((Symbol("a") === "") || (Symbol("a") === "abc")) .analyze comparePlans(optimized, correctAnswer) } test("null pattern") { - val originalQuery = testRelation.where('a like Literal(null, StringType)).analyze + val originalQuery = testRelation.where(Symbol("a") like Literal(null, StringType)).analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, testRelation.where(Literal(null, BooleanType)).analyze) } test("test like escape syntax") { - val originalQuery1 = testRelation.where('a.like("abc#%", '#')) + val originalQuery1 = testRelation.where(Symbol("a").like("abc#%", '#')) val optimized1 = Optimize.execute(originalQuery1.analyze) comparePlans(optimized1, originalQuery1.analyze) - val originalQuery2 = testRelation.where('a.like("abc#%abc", '#')) + val originalQuery2 = testRelation.where(Symbol("a").like("abc#%abc", '#')) val optimized2 = Optimize.execute(originalQuery2.analyze) comparePlans(optimized2, originalQuery2.analyze) } @@ -120,47 +120,47 @@ class LikeSimplificationSuite extends PlanTest { test("SPARK-33677: LikeSimplification should be skipped if pattern contains any escapeChar") { val originalQuery1 = testRelation - .where(('a like "abc%") || ('a like "\\abc%")) + .where((Symbol("a") like "abc%") || (Symbol("a") like "\\abc%")) val optimized1 = Optimize.execute(originalQuery1.analyze) val correctAnswer1 = testRelation - .where(StartsWith('a, "abc") || ('a like "\\abc%")) + .where(StartsWith(Symbol("a"), "abc") || (Symbol("a") like "\\abc%")) .analyze comparePlans(optimized1, correctAnswer1) val originalQuery2 = testRelation - .where(('a like "%xyz") || ('a like "%xyz\\")) + .where((Symbol("a") like "%xyz") || (Symbol("a") like "%xyz\\")) val optimized2 = Optimize.execute(originalQuery2.analyze) val correctAnswer2 = testRelation - .where(EndsWith('a, "xyz") || ('a like "%xyz\\")) + .where(EndsWith(Symbol("a"), "xyz") || (Symbol("a") like "%xyz\\")) .analyze comparePlans(optimized2, correctAnswer2) val originalQuery3 = testRelation - .where(('a like ("@bc%def", '@')) || ('a like "abc%def")) + .where((Symbol("a") like ("@bc%def", '@')) || (Symbol("a") like "abc%def")) val optimized3 = Optimize.execute(originalQuery3.analyze) val correctAnswer3 = testRelation - .where(('a like ("@bc%def", '@')) || - (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) + .where((Symbol("a") like ("@bc%def", '@')) || (Length(Symbol("a")) >= 6 && + (StartsWith(Symbol("a"), "abc") && EndsWith(Symbol("a"), "def")))) .analyze comparePlans(optimized3, correctAnswer3) val originalQuery4 = testRelation - .where(('a like "%mn%") || ('a like ("%mn%", '%'))) + .where((Symbol("a") like "%mn%") || (Symbol("a") like ("%mn%", '%'))) val optimized4 = Optimize.execute(originalQuery4.analyze) val correctAnswer4 = testRelation - .where(Contains('a, "mn") || ('a like 
("%mn%", '%'))) + .where(Contains(Symbol("a"), "mn") || (Symbol("a") like ("%mn%", '%'))) .analyze comparePlans(optimized4, correctAnswer4) val originalQuery5 = testRelation - .where(('a like "abc") || ('a like ("abbc", 'b'))) + .where((Symbol("a") like "abc") || (Symbol("a") like ("abbc", 'b'))) val optimized5 = Optimize.execute(originalQuery5.analyze) val correctAnswer5 = testRelation - .where(('a === "abc") || ('a like ("abbc", 'b'))) + .where((Symbol("a") === "abc") || (Symbol("a") like ("abbc", 'b'))) .analyze comparePlans(optimized5, correctAnswer5) } @@ -168,15 +168,16 @@ class LikeSimplificationSuite extends PlanTest { test("simplify LikeAll") { val originalQuery = testRelation - .where(('a likeAll( + .where((Symbol("a") likeAll( "abc%", "abc\\%", "%xyz", "abc\\%def", "abc%def", "%mn%", "%mn\\%", "", "abc"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where((((((StartsWith('a, "abc") && EndsWith('a, "xyz")) && - (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) && - Contains('a, "mn")) && ('a === "")) && ('a === "abc")) && - ('a likeAll("abc\\%", "abc\\%def", "%mn\\%"))) + .where((((((StartsWith(Symbol("a"), "abc") && EndsWith(Symbol("a"), "xyz")) && + (Length(Symbol("a")) >= 6 && (StartsWith(Symbol("a"), "abc") && + EndsWith(Symbol("a"), "def")))) && Contains(Symbol("a"), "mn")) && + (Symbol("a") === "")) && (Symbol("a") === "abc")) && + (Symbol("a") likeAll("abc\\%", "abc\\%def", "%mn\\%"))) .analyze comparePlans(optimized, correctAnswer) @@ -185,15 +186,16 @@ class LikeSimplificationSuite extends PlanTest { test("simplify NotLikeAll") { val originalQuery = testRelation - .where(('a notLikeAll( + .where((Symbol("a") notLikeAll( "abc%", "abc\\%", "%xyz", "abc\\%def", "abc%def", "%mn%", "%mn\\%", "", "abc"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where((((((Not(StartsWith('a, "abc")) && Not(EndsWith('a, "xyz"))) && - Not(Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) && - Not(Contains('a, "mn"))) && Not('a === "")) && Not('a === "abc")) && - ('a notLikeAll("abc\\%", "abc\\%def", "%mn\\%"))) + .where((((((Not(StartsWith(Symbol("a"), "abc")) && Not(EndsWith(Symbol("a"), "xyz"))) && + Not(Length(Symbol("a")) >= 6 && (StartsWith(Symbol("a"), "abc") && + EndsWith(Symbol("a"), "def")))) && Not(Contains(Symbol("a"), "mn"))) && + Not(Symbol("a") === "")) && Not(Symbol("a") === "abc")) && + (Symbol("a") notLikeAll("abc\\%", "abc\\%def", "%mn\\%"))) .analyze comparePlans(optimized, correctAnswer) @@ -202,15 +204,16 @@ class LikeSimplificationSuite extends PlanTest { test("simplify LikeAny") { val originalQuery = testRelation - .where(('a likeAny( + .where((Symbol("a") likeAny( "abc%", "abc\\%", "%xyz", "abc\\%def", "abc%def", "%mn%", "%mn\\%", "", "abc"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where((((((StartsWith('a, "abc") || EndsWith('a, "xyz")) || - (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) || - Contains('a, "mn")) || ('a === "")) || ('a === "abc")) || - ('a likeAny("abc\\%", "abc\\%def", "%mn\\%"))) + .where((((((StartsWith(Symbol("a"), "abc") || EndsWith(Symbol("a"), "xyz")) || + (Length(Symbol("a")) >= 6 && (StartsWith(Symbol("a"), "abc") && + EndsWith(Symbol("a"), "def")))) || + Contains(Symbol("a"), "mn")) || (Symbol("a") === "")) || (Symbol("a") === "abc")) || + (Symbol("a") likeAny("abc\\%", "abc\\%def", "%mn\\%"))) .analyze comparePlans(optimized, 
correctAnswer) @@ -219,15 +222,17 @@ class LikeSimplificationSuite extends PlanTest { test("simplify NotLikeAny") { val originalQuery = testRelation - .where(('a notLikeAny( + .where((Symbol("a") notLikeAny( "abc%", "abc\\%", "%xyz", "abc\\%def", "abc%def", "%mn%", "%mn\\%", "", "abc"))) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .where((((((Not(StartsWith('a, "abc")) || Not(EndsWith('a, "xyz"))) || - Not(Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) || - Not(Contains('a, "mn"))) || Not('a === "")) || Not('a === "abc")) || - ('a notLikeAny("abc\\%", "abc\\%def", "%mn\\%"))) + .where((((((Not(StartsWith(Symbol("a"), "abc")) || Not(EndsWith(Symbol("a"), "xyz"))) || + Not(Length(Symbol("a")) >= 6 && (StartsWith(Symbol("a"), "abc") && + EndsWith(Symbol("a"), "def")))) || + Not(Contains(Symbol("a"), "mn"))) || Not(Symbol("a") === "")) || + Not(Symbol("a") === "abc")) || + (Symbol("a") notLikeAny("abc\\%", "abc\\%def", "%mn\\%"))) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala index bb23b63c03ce..fdd8ae781e96 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LimitPushdownSuite.scala @@ -45,8 +45,8 @@ class LimitPushdownSuite extends PlanTest { private val testRelation2 = LocalRelation.fromExternalRows( Seq("d".attr.int, "e".attr.int, "f".attr.int), 1.to(6).map(_ => Row(1, 2, 3))) - private val x = testRelation.subquery('x) - private val y = testRelation.subquery('y) + private val x = testRelation.subquery(Symbol("x")) + private val y = testRelation.subquery(Symbol("y")) // Union --------------------------------------------------------------------------------------- @@ -76,20 +76,24 @@ class LimitPushdownSuite extends PlanTest { test("Union: no limit to both sides if children having smaller limit values") { val unionQuery = - Union(testRelation.limit(1), testRelation2.select('d, 'e, 'f).limit(1)).limit(2) + Union(testRelation.limit(1), + testRelation2.select(Symbol("d"), Symbol("e"), Symbol("f")).limit(1)).limit(2) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Union(testRelation.limit(1), testRelation2.select('d, 'e, 'f).limit(1)).analyze + Union(testRelation.limit(1), + testRelation2.select(Symbol("d"), Symbol("e"), Symbol("f")).limit(1)).analyze comparePlans(unionOptimized, unionCorrectAnswer) } test("Union: limit to each sides if children having larger limit values") { val unionQuery = - Union(testRelation.limit(3), testRelation2.select('d, 'e, 'f).limit(4)).limit(2) + Union(testRelation.limit(3), + testRelation2.select(Symbol("d"), Symbol("e"), Symbol("f")).limit(4)).limit(2) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = Limit(2, Union( - LocalLimit(2, testRelation), LocalLimit(2, testRelation2.select('d, 'e, 'f)))).analyze + LocalLimit(2, testRelation), + LocalLimit(2, testRelation2.select(Symbol("d"), Symbol("e"), Symbol("f"))))).analyze comparePlans(unionOptimized, unionCorrectAnswer) } @@ -153,7 +157,7 @@ class LimitPushdownSuite extends PlanTest { } test("full outer join where neither side is limited and left side has larger statistics") { - val xBig = testRelation.copy(data = Seq.fill(10)(null)).subquery('x) + val 
xBig = testRelation.copy(data = Seq.fill(10)(null)).subquery(Symbol("x")) assert(xBig.stats.sizeInBytes > y.stats.sizeInBytes) val originalQuery = xBig.join(y, FullOuter).limit(1).analyze val optimized = Optimize.execute(originalQuery) @@ -162,7 +166,7 @@ class LimitPushdownSuite extends PlanTest { } test("full outer join where neither side is limited and right side has larger statistics") { - val yBig = testRelation.copy(data = Seq.fill(10)(null)).subquery('y) + val yBig = testRelation.copy(data = Seq.fill(10)(null)).subquery(Symbol("y")) assert(x.stats.sizeInBytes < yBig.stats.sizeInBytes) val originalQuery = x.join(yBig, FullOuter).limit(1).analyze val optimized = Optimize.execute(originalQuery) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala index c83ab375ee15..1f1afd1edc9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala @@ -41,16 +41,16 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { private val name = StructType.fromDDL("first string, middle string, last string") private val employer = StructType.fromDDL("id int, company struct") private val contact = LocalRelation( - 'id.int, - 'name.struct(name), - 'address.string, - 'friends.array(name), - 'relatives.map(StringType, name), - 'employer.struct(employer)) + Symbol("id").int, + Symbol("name").struct(name), + Symbol("address").string, + Symbol("friends").array(name), + Symbol("relatives").map(StringType, name), + Symbol("employer").struct(employer)) test("Pushing a single nested field projection") { def testSingleFieldPushDown(op: LogicalPlan => LogicalPlan): Unit = { - val middle = GetStructField('name, 1, Some("middle")) + val middle = GetStructField(Symbol("name"), 1, Some("middle")) val query = op(contact).select(middle).analyze val optimized = Optimize.execute(query) val expected = op(contact.select(middle)).analyze @@ -63,18 +63,18 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { } test("Pushing multiple nested field projection") { - val first = GetStructField('name, 0, Some("first")) - val last = GetStructField('name, 2, Some("last")) + val first = GetStructField(Symbol("name"), 0, Some("first")) + val last = GetStructField(Symbol("name"), 2, Some("last")) val query = contact .limit(5) - .select('id, first, last) + .select(Symbol("id"), first, last) .analyze val optimized = Optimize.execute(query) val expected = contact - .select('id, first, last) + .select(Symbol("id"), first, last) .limit(5) .analyze @@ -82,12 +82,12 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { } test("function with nested field inputs") { - val first = GetStructField('name, 0, Some("first")) - val last = GetStructField('name, 2, Some("last")) + val first = GetStructField(Symbol("name"), 0, Some("first")) + val last = GetStructField(Symbol("name"), 2, Some("last")) val query = contact .limit(5) - .select('id, ConcatWs(Seq(first, last))) + .select(Symbol("id"), ConcatWs(Seq(first, last))) .analyze val optimized = Optimize.execute(query) @@ -95,18 +95,19 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { val aliases = collectGeneratedAliases(optimized) val expected = contact - .select('id, first.as(aliases(0)), last.as(aliases(1))) + .select(Symbol("id"), 
first.as(aliases(0)), last.as(aliases(1))) .limit(5) .select( - 'id, + Symbol("id"), ConcatWs(Seq($"${aliases(0)}", $"${aliases(1)}")).as("concat_ws(name.first, name.last)")) .analyze comparePlans(optimized, expected) } test("multi-level nested field") { - val field1 = GetStructField(GetStructField('employer, 1, Some("company")), 0, Some("name")) - val field2 = GetStructField('employer, 0, Some("id")) + val field1 = + GetStructField(GetStructField(Symbol("employer"), 1, Some("company")), 0, Some("name")) + val field2 = GetStructField(Symbol("employer"), 0, Some("id")) val query = contact .limit(5) @@ -123,18 +124,18 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { } test("Push original case-sensitive names") { - val first1 = GetStructField('name, 0, Some("first")) - val first2 = GetStructField('name, 1, Some("FIRST")) + val first1 = GetStructField(Symbol("name"), 0, Some("first")) + val first2 = GetStructField(Symbol("name"), 1, Some("FIRST")) val query = contact .limit(5) - .select('id, first1, first2) + .select(Symbol("id"), first1, first2) .analyze val optimized = Optimize.execute(query) val expected = contact - .select('id, first1, first2) + .select(Symbol("id"), first1, first2) .limit(5) .analyze @@ -143,15 +144,15 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { test("Pushing a single nested field projection - negative") { val ops = Seq( - (input: LogicalPlan) => input.distribute('name)(1), - (input: LogicalPlan) => input.orderBy('name.asc), - (input: LogicalPlan) => input.sortBy('name.asc), + (input: LogicalPlan) => input.distribute(Symbol("name"))(1), + (input: LogicalPlan) => input.orderBy(Symbol("name").asc), + (input: LogicalPlan) => input.sortBy(Symbol("name").asc), (input: LogicalPlan) => input.union(input) ) val queries = ops.map { op => - op(contact.select('name)) - .select(GetStructField('name, 1, Some("middle"))) + op(contact.select(Symbol("name"))) + .select(GetStructField(Symbol("name"), 1, Some("middle"))) .analyze } @@ -161,20 +162,20 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { comparePlans(optimized, expected) } val expectedUnion = - contact.select('name).union(contact.select('name.as('name))) - .select(GetStructField('name, 1, Some("middle"))).analyze + contact.select(Symbol("name")).union(contact.select(Symbol("name").as(Symbol("name")))) + .select(GetStructField(Symbol("name"), 1, Some("middle"))).analyze comparePlans(optimizedUnion, expectedUnion) } test("Pushing a single nested field projection through filters - negative") { val ops = Array( - (input: LogicalPlan) => input.where('name.isNotNull), + (input: LogicalPlan) => input.where(Symbol("name").isNotNull), (input: LogicalPlan) => input.where($"name.middle".isNotNull) ) val queries = ops.map { op => op(contact) - .select(GetStructField('name, 1, Some("middle"))) + .select(GetStructField(Symbol("name"), 1, Some("middle"))) .analyze } @@ -189,25 +190,26 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { test("Do not optimize when parent field is used") { val query = contact .limit(5) - .select('id, GetStructField('name, 0, Some("first")), 'name) + .select(Symbol("id"), GetStructField(Symbol("name"), 0, Some("first")), Symbol("name")) .analyze val optimized = Optimize.execute(query) val expected = contact - .select('id, 'name) + .select(Symbol("id"), Symbol("name")) .limit(5) - .select('id, GetStructField('name, 0, Some("first")), 'name) + .select(Symbol("id"), GetStructField(Symbol("name"), 0, Some("first")), Symbol("name")) .analyze 
comparePlans(optimized, expected) } test("Some nested column means the whole structure") { - val nestedRelation = LocalRelation('a.struct('b.struct('c.int, 'd.int, 'e.int))) + val nestedRelation = LocalRelation( + Symbol("a").struct(Symbol("b").struct(Symbol("c").int, Symbol("d").int, Symbol("e").int))) val query = nestedRelation .limit(5) - .select(GetStructField('a, 0, Some("b"))) + .select(GetStructField(Symbol("a"), 0, Some("b"))) .analyze val optimized = Optimize.execute(query) @@ -216,12 +218,12 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { } test("nested field pruning for getting struct field in array of struct") { - val field1 = GetArrayStructFields(child = 'friends, + val field1 = GetArrayStructFields(child = Symbol("friends"), field = StructField("first", StringType), ordinal = 0, numFields = 3, containsNull = true) - val field2 = GetStructField('employer, 0, Some("id")) + val field2 = GetStructField(Symbol("employer"), 0, Some("id")) val query = contact .limit(5) @@ -238,8 +240,8 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { } test("nested field pruning for getting struct field in map") { - val field1 = GetStructField(GetMapValue('relatives, Literal("key")), 0, Some("first")) - val field2 = GetArrayStructFields(child = MapValues('relatives), + val field1 = GetStructField(GetMapValue(Symbol("relatives"), Literal("key")), 0, Some("first")) + val field2 = GetArrayStructFields(child = MapValues(Symbol("relatives")), field = StructField("middle", StringType), ordinal = 1, numFields = 3, @@ -260,15 +262,15 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { } test("SPARK-27633: Do not generate redundant aliases if parent nested field is aliased too") { - val nestedRelation = LocalRelation('a.struct('b.struct('c.int, - 'd.struct('f.int, 'g.int)), 'e.int)) + val nestedRelation = LocalRelation(Symbol("a").struct(Symbol("b").struct(Symbol("c").int, + Symbol("d").struct(Symbol("f").int, Symbol("g").int)), Symbol("e").int)) // `a.b` - val first = 'a.getField("b") + val first = Symbol("a").getField("b") // `a.b.c` + 1 - val second = 'a.getField("b").getField("c") + Literal(1) + val second = Symbol("a").getField("b").getField("c") + Literal(1) // `a.b.d.f` - val last = 'a.getField("b").getField("d").getField("f") + val last = Symbol("a").getField("b").getField("d").getField("f") val query = nestedRelation .limit(5) @@ -292,8 +294,8 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { test("Nested field pruning for Project and Generate") { val query = contact - .generate(Explode('friends.getField("first")), outputNames = Seq("explode")) - .select('explode, 'friends.getField("middle")) + .generate(Explode(Symbol("friends").getField("first")), outputNames = Seq("explode")) + .select(Symbol("explode"), Symbol("friends").getField("middle")) .analyze val optimized = Optimize.execute(query) @@ -301,27 +303,27 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { val expected = contact .select( - 'friends.getField("middle").as(aliases(0)), - 'friends.getField("first").as(aliases(1))) + Symbol("friends").getField("middle").as(aliases(0)), + Symbol("friends").getField("first").as(aliases(1))) .generate(Explode($"${aliases(1)}"), unrequiredChildIndex = Seq(1), outputNames = Seq("explode")) - .select('explode, $"${aliases(0)}".as("friends.middle")) + .select(Symbol("explode"), $"${aliases(0)}".as("friends.middle")) .analyze comparePlans(optimized, expected) } test("Nested field pruning for Generate") { val query = contact - 
.generate(Explode('friends.getField("first")), outputNames = Seq("explode")) - .select('explode) + .generate(Explode(Symbol("friends").getField("first")), outputNames = Seq("explode")) + .select(Symbol("explode")) .analyze val optimized = Optimize.execute(query) val aliases = collectGeneratedAliases(optimized) val expected = contact - .select('friends.getField("first").as(aliases(0))) + .select(Symbol("friends").getField("first").as(aliases(0))) .generate(Explode($"${aliases(0)}"), unrequiredChildIndex = Seq(0), outputNames = Seq("explode")) @@ -331,23 +333,23 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { test("Nested field pruning for Project and Generate: not prune on generator output") { val companies = LocalRelation( - 'id.int, - 'employers.array(employer)) + Symbol("id").int, + Symbol("employers").array(employer)) val query = companies - .generate(Explode('employers.getField("company")), outputNames = Seq("company")) - .select('company.getField("name")) + .generate(Explode(Symbol("employers").getField("company")), outputNames = Seq("company")) + .select(Symbol("company").getField("name")) .analyze val optimized = Optimize.execute(query) val aliases = collectGeneratedAliases(optimized) val expected = companies - .select('employers.getField("company").as(aliases(0))) + .select(Symbol("employers").getField("company").as(aliases(0))) .generate(Explode($"${aliases(0)}"), unrequiredChildIndex = Seq(0), outputNames = Seq("company")) - .select('company.getField("name").as("company.name")) + .select(Symbol("company").getField("name").as("company.name")) .analyze comparePlans(optimized, expected) } @@ -355,17 +357,17 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { test("Nested field pruning for Generate: not prune on required child output") { val query = contact .generate( - Explode('friends.getField("first")), + Explode(Symbol("friends").getField("first")), outputNames = Seq("explode")) - .select('explode, 'friends) + .select(Symbol("explode"), Symbol("friends")) .analyze val optimized = Optimize.execute(query) val expected = contact - .select('friends) - .generate(Explode('friends.getField("first")), + .select(Symbol("friends")) + .generate(Explode(Symbol("friends").getField("first")), outputNames = Seq("explode")) - .select('explode, 'friends) + .select(Symbol("explode"), Symbol("friends")) .analyze comparePlans(optimized, expected) } @@ -380,7 +382,7 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { val aliases1 = collectGeneratedAliases(optimized1) val expected1 = contact - .select('id, 'name.getField("middle").as(aliases1(0))) + .select(Symbol("id"), Symbol("name").getField("middle").as(aliases1(0))) .distribute($"id")(1) .select($"${aliases1(0)}".as("middle")) .analyze @@ -395,7 +397,7 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { val aliases2 = collectGeneratedAliases(optimized2) val expected2 = contact - .select('name.getField("middle").as(aliases2(0))) + .select(Symbol("name").getField("middle").as(aliases2(0))) .distribute($"${aliases2(0)}")(1) .select($"${aliases2(0)}".as("middle")) .analyze @@ -413,8 +415,8 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { test("Nested field pruning through Join") { val department = LocalRelation( - 'depID.int, - 'personID.string) + Symbol("depID").int, + Symbol("personID").string) val query1 = contact.join(department, condition = Some($"id" === $"depID")) .select($"name.middle") @@ -423,8 +425,8 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { val aliases1 = 
collectGeneratedAliases(optimized1) - val expected1 = contact.select('id, 'name.getField("middle").as(aliases1(0))) - .join(department.select('depID), condition = Some($"id" === $"depID")) + val expected1 = contact.select(Symbol("id"), Symbol("name").getField("middle").as(aliases1(0))) + .join(department.select(Symbol("depID")), condition = Some($"id" === $"depID")) .select($"${aliases1(0)}".as("middle")) .analyze comparePlans(optimized1, expected1) @@ -437,15 +439,16 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { val aliases2 = collectGeneratedAliases(optimized2) val expected2 = contact.select( - 'name.getField("first").as(aliases2(0)), - 'name.getField("middle").as(aliases2(1))) - .join(department.select('personID), condition = Some($"${aliases2(1)}" === $"personID")) + Symbol("name").getField("first").as(aliases2(0)), + Symbol("name").getField("middle").as(aliases2(1))) + .join(department.select(Symbol("personID")), + condition = Some($"${aliases2(1)}" === $"personID")) .select($"${aliases2(0)}".as("first")) .analyze comparePlans(optimized2, expected2) - val contact2 = LocalRelation('name2.struct(name)) - val query3 = contact.select('name) + val contact2 = LocalRelation(Symbol("name2").struct(name)) + val query3 = contact.select(Symbol("name")) .join(contact2, condition = Some($"name" === $"name2")) .select($"name.first") .analyze @@ -461,7 +464,7 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { val expected1 = basePlan( contact - .select($"id", 'name.getField("first").as(aliases1(0))) + .select($"id", Symbol("name").getField("first").as(aliases1(0))) ).groupBy($"id")(first($"${aliases1(0)}").as("first")).analyze comparePlans(optimized1, expected1) @@ -471,7 +474,8 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { val expected2 = basePlan( contact - .select('name.getField("last").as(aliases2(0)), 'name.getField("first").as(aliases2(1))) + .select(Symbol("name").getField("last").as(aliases2(0)), + Symbol("name").getField("first").as(aliases2(1))) ).groupBy($"${aliases2(0)}")(first($"${aliases2(1)}").as("first")).analyze comparePlans(optimized2, expected2) } @@ -495,7 +499,7 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { val spec = windowSpec($"address" :: Nil, $"id".asc :: Nil, UnspecifiedFrame) val winExpr = windowExpr(RowNumber(), spec) val query = contact - .select($"name.first", winExpr.as('window)) + .select($"name.first", winExpr.as(Symbol("window"))) .orderBy($"name.last".asc) .analyze val optimized = Optimize.execute(query) @@ -513,7 +517,7 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { test("Nested field pruning for Filter with other supported operators") { val spec = windowSpec($"address" :: Nil, $"id".asc :: Nil, UnspecifiedFrame) val winExpr = windowExpr(RowNumber(), spec) - val query1 = contact.select($"name.first", winExpr.as('window)) + val query1 = contact.select($"name.first", winExpr.as(Symbol("window"))) .where($"window" === 1 && $"name.first" === "a") .analyze val optimized1 = Optimize.execute(query1) @@ -558,8 +562,8 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { comparePlans(optimized3, expected3) val department = LocalRelation( - 'depID.int, - 'personID.string) + Symbol("depID").int, + Symbol("personID").string) val query4 = contact.join(department, condition = Some($"id" === $"depID")) .where($"name.first" === "a") .select($"name.first") @@ -568,7 +572,7 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { val aliases4 = collectGeneratedAliases(optimized4) val 
expected4 = contact .select($"id", $"name.first".as(aliases4(1))) - .join(department.select('depID), condition = Some($"id" === $"depID")) + .join(department.select(Symbol("depID")), condition = Some($"id" === $"depID")) .select($"${aliases4(1)}".as(aliases4(0))) .where($"${aliases4(0)}" === "a") .select($"${aliases4(0)}".as("first")) @@ -637,7 +641,7 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { Seq(ConcatWs(Seq($"name.first", $"name.middle")), ConcatWs(Seq($"name.middle", $"name.first"))) ), - Seq('a.string, 'b.string), + Seq(Symbol("a").string, Symbol("b").string), basePlan(contact) ).analyze val optimized1 = Optimize.execute(query1) @@ -649,10 +653,10 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { Seq(ConcatWs(Seq($"${aliases1(0)}", $"${aliases1(1)}")), ConcatWs(Seq($"${aliases1(1)}", $"${aliases1(0)}"))) ), - Seq('a.string, 'b.string), + Seq(Symbol("a").string, Symbol("b").string), basePlan(contact.select( - 'name.getField("first").as(aliases1(0)), - 'name.getField("middle").as(aliases1(1)))) + Symbol("name").getField("first").as(aliases1(0)), + Symbol("name").getField("middle").as(aliases1(1)))) ).analyze comparePlans(optimized1, expected1) } @@ -670,7 +674,7 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { Seq($"name", $"name.middle"), Seq($"name", ConcatWs(Seq($"name.middle", $"name.first"))) ), - Seq('a.string, 'b.string), + Seq(Symbol("a").string, Symbol("b").string), contact ).analyze val optimized2 = Optimize.execute(query2) @@ -679,7 +683,7 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { Seq($"name", $"name.middle"), Seq($"name", ConcatWs(Seq($"name.middle", $"name.first"))) ), - Seq('a.string, 'b.string), + Seq(Symbol("a").string, Symbol("b").string), contact.select($"name") ).analyze comparePlans(optimized2, expected2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala index bb9919f94eef..d9df144057aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala @@ -30,9 +30,9 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest { val batches = Batch("NormalizeFloatingPointNumbers", Once, NormalizeFloatingNumbers) :: Nil } - val testRelation1 = LocalRelation('a.double) + val testRelation1 = LocalRelation(Symbol("a").double) val a = testRelation1.output(0) - val testRelation2 = LocalRelation('a.double) + val testRelation2 = LocalRelation(Symbol("a").double) val b = testRelation2.output(0) test("normalize floating points in window function expressions") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala index 6d7c4c3c7e9d..d8d44fcb4cb7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala @@ -73,18 +73,20 @@ class ObjectSerializerPruningSuite extends PlanTest { } test("SPARK-26619: Prune the unused serializers from SerializeFromObject") { - val testRelation = LocalRelation('_1.int, '_2.int) + val testRelation = 
LocalRelation(Symbol("_1").int, Symbol("_2").int) val serializerObject = CatalystSerde.serialize[(Int, Int)]( CatalystSerde.deserialize[(Int, Int)](testRelation)) - val query = serializerObject.select('_1) + val query = serializerObject.select(Symbol("_1")) val optimized = Optimize.execute(query.analyze) - val expected = serializerObject.copy(serializer = Seq(serializerObject.serializer.head)).analyze + val expected = serializerObject.copy(serializer = + Seq(serializerObject.serializer.head)).analyze comparePlans(optimized, expected) } test("Prune nested serializers") { withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") { - val testRelation = LocalRelation('_1.struct(StructType.fromDDL("_1 int, _2 string")), '_2.int) + val testRelation = LocalRelation( + Symbol("_1").struct(StructType.fromDDL("_1 int, _2 string")), Symbol("_2").int) val serializerObject = CatalystSerde.serialize[((Int, String), Int)]( CatalystSerde.deserialize[((Int, String), Int)](testRelation)) val query = serializerObject.select($"_1._1") @@ -111,7 +113,8 @@ class ObjectSerializerPruningSuite extends PlanTest { test("SPARK-32652: Prune nested serializers: RowEncoder") { withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") { - val testRelation = LocalRelation('i.struct(StructType.fromDDL("a int, b string")), 'j.int) + val testRelation = LocalRelation( + Symbol("i").struct(StructType.fromDDL("a int, b string")), Symbol("j").int) val rowEncoder = RowEncoder(new StructType() .add("i", new StructType().add("a", "int").add("b", "string")) .add("j", "int")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvExprsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvExprsSuite.scala index 9b208cf2b57c..4c1a57108cda 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvExprsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvExprsSuite.scala @@ -43,29 +43,31 @@ class OptimizeCsvExprsSuite extends PlanTest with ExpressionEvalHelper { val schema = StructType.fromDDL("a int, b int") - private val csvAttr = 'csv.string + private val csvAttr = Symbol("csv").string private val testRelation = LocalRelation(csvAttr) test("SPARK-32968: prune unnecessary columns from GetStructField + from_csv") { val options = Map.empty[String, String] val query1 = testRelation - .select(GetStructField(CsvToStructs(schema, options, 'csv), 0)) + .select(GetStructField(CsvToStructs(schema, options, Symbol("csv")), 0)) val optimized1 = Optimizer.execute(query1.analyze) val prunedSchema1 = StructType.fromDDL("a int") val expected1 = testRelation - .select(GetStructField(CsvToStructs(schema, options, 'csv, None, Some(prunedSchema1)), 0)) + .select(GetStructField( + CsvToStructs(schema, options, Symbol("csv"), None, Some(prunedSchema1)), 0)) .analyze comparePlans(optimized1, expected1) val query2 = testRelation - .select(GetStructField(CsvToStructs(schema, options, 'csv), 1)) + .select(GetStructField(CsvToStructs(schema, options, Symbol("csv")), 1)) val optimized2 = Optimizer.execute(query2.analyze) val prunedSchema2 = StructType.fromDDL("b int") val expected2 = testRelation - .select(GetStructField(CsvToStructs(schema, options, 'csv, None, Some(prunedSchema2)), 0)) + .select(GetStructField( + CsvToStructs(schema, options, Symbol("csv"), None, Some(prunedSchema2)), 0)) .analyze comparePlans(optimized2, expected2) } @@ -74,7 +76,7 @@ class 
OptimizeCsvExprsSuite extends PlanTest with ExpressionEvalHelper { val options = Map("mode" -> "failfast") val query = testRelation - .select(GetStructField(CsvToStructs(schema, options, 'csv), 0)) + .select(GetStructField(CsvToStructs(schema, options, Symbol("csv")), 0)) val optimized = Optimizer.execute(query.analyze) val expected = query.analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index a36083b84704..a7fb4f2f98cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -40,7 +40,7 @@ class OptimizeInSuite extends PlanTest { OptimizeIn) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) test("OptimizedIn test: Remove deterministic repetitions") { val originalQuery = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala index ccbc61e8a498..8b06c78aacb1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala @@ -44,8 +44,8 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val schema = StructType.fromDDL("a int, b int") - private val structAtt = 'struct.struct(schema).notNull - private val jsonAttr = 'json.string + private val structAtt = Symbol("struct").struct(schema).notNull + private val jsonAttr = Symbol("json").string private val testRelation = LocalRelation(structAtt) private val testRelation2 = LocalRelation(jsonAttr) @@ -54,10 +54,11 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val options = Map.empty[String, String] val query1 = testRelation - .select(JsonToStructs(schema, options, StructsToJson(options, 'struct)).as("struct")) + .select(JsonToStructs( + schema, options, StructsToJson(options, Symbol("struct"))).as("struct")) val optimized1 = Optimizer.execute(query1.analyze) - val expected = testRelation.select('struct.as("struct")).analyze + val expected = testRelation.select(Symbol("struct").as("struct")).analyze comparePlans(optimized1, expected) val query2 = testRelation @@ -65,7 +66,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { JsonToStructs(schema, options, StructsToJson(options, JsonToStructs(schema, options, - StructsToJson(options, 'struct)))).as("struct")) + StructsToJson(options, Symbol("struct"))))).as("struct")) val optimized2 = Optimizer.execute(query2.analyze) comparePlans(optimized2, expected) @@ -76,11 +77,11 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val schema = StructType.fromDDL("a int") val query = testRelation - .select(JsonToStructs(schema, options, StructsToJson(options, 'struct)).as("struct")) + .select(JsonToStructs(schema, options, StructsToJson(options, Symbol("struct"))).as("struct")) val optimized = Optimizer.execute(query.analyze) val expected = testRelation.select( - JsonToStructs(schema, options, StructsToJson(options, 'struct)).as("struct")).analyze + JsonToStructs(schema, options, StructsToJson(options, 
Symbol("struct"))).as("struct")).analyze comparePlans(optimized, expected) } @@ -90,11 +91,13 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val schema = StructType.fromDDL("a int, B int") val query = testRelation - .select(JsonToStructs(schema, options, StructsToJson(options, 'struct)).as("struct")) + .select(JsonToStructs( + schema, options, StructsToJson(options, Symbol("struct"))).as("struct")) val optimized = Optimizer.execute(query.analyze) val expected = testRelation.select( - JsonToStructs(schema, options, StructsToJson(options, 'struct)).as("struct")).analyze + JsonToStructs( + schema, options, StructsToJson(options, Symbol("struct"))).as("struct")).analyze comparePlans(optimized, expected) } } @@ -104,17 +107,17 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val nonNullSchema = StructType( StructField("a", IntegerType, false) :: StructField("b", IntegerType, false) :: Nil) - val structAtt = 'struct.struct(nonNullSchema).notNull + val structAtt = Symbol("struct").struct(nonNullSchema).notNull val testRelationWithNonNullAttr = LocalRelation(structAtt) val schema = StructType.fromDDL("a int, b int") val query = testRelationWithNonNullAttr - .select(JsonToStructs(schema, options, StructsToJson(options, 'struct)).as("struct")) + .select(JsonToStructs(schema, options, StructsToJson(options, Symbol("struct"))).as("struct")) val optimized = Optimizer.execute(query.analyze) val expected = testRelationWithNonNullAttr.select( - JsonToStructs(schema, options, StructsToJson(options, 'struct)).as("struct")).analyze + JsonToStructs(schema, options, StructsToJson(options, Symbol("struct"))).as("struct")).analyze comparePlans(optimized, expected) } @@ -122,11 +125,11 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val options = Map("testOption" -> "test") val query = testRelation - .select(JsonToStructs(schema, options, StructsToJson(options, 'struct)).as("struct")) + .select(JsonToStructs(schema, options, StructsToJson(options, Symbol("struct"))).as("struct")) val optimized = Optimizer.execute(query.analyze) val expected = testRelation.select( - JsonToStructs(schema, options, StructsToJson(options, 'struct)).as("struct")).analyze + JsonToStructs(schema, options, StructsToJson(options, Symbol("struct"))).as("struct")).analyze comparePlans(optimized, expected) } @@ -137,19 +140,19 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val query1 = testRelation .select(JsonToStructs(schema, options, - StructsToJson(options, 'struct, Option(PST.getId)), UTC_OPT).as("struct")) + StructsToJson(options, Symbol("struct"), Option(PST.getId)), UTC_OPT).as("struct")) val optimized1 = Optimizer.execute(query1.analyze) val expected1 = testRelation.select( JsonToStructs(schema, options, - StructsToJson(options, 'struct, Option(PST.getId)), UTC_OPT).as("struct")).analyze + StructsToJson(options, Symbol("struct"), Option(PST.getId)), UTC_OPT).as("struct")).analyze comparePlans(optimized1, expected1) val query2 = testRelation .select(JsonToStructs(schema, options, - StructsToJson(options, 'struct, UTC_OPT), UTC_OPT).as("struct")) + StructsToJson(options, Symbol("struct"), UTC_OPT), UTC_OPT).as("struct")) val optimized2 = Optimizer.execute(query2.analyze) - val expected2 = testRelation.select('struct.as("struct")).analyze + val expected2 = testRelation.select(Symbol("struct").as("struct")).analyze comparePlans(optimized2, expected2) } @@ -157,21 +160,21 @@ class OptimizeJsonExprsSuite extends PlanTest 
with ExpressionEvalHelper { val options = Map.empty[String, String] val query1 = testRelation2 - .select(GetStructField(JsonToStructs(schema, options, 'json), 0)) + .select(GetStructField(JsonToStructs(schema, options, Symbol("json")), 0)) val optimized1 = Optimizer.execute(query1.analyze) val prunedSchema1 = StructType.fromDDL("a int") val expected1 = testRelation2 - .select(GetStructField(JsonToStructs(prunedSchema1, options, 'json), 0)).analyze + .select(GetStructField(JsonToStructs(prunedSchema1, options, Symbol("json")), 0)).analyze comparePlans(optimized1, expected1) val query2 = testRelation2 - .select(GetStructField(JsonToStructs(schema, options, 'json), 1)) + .select(GetStructField(JsonToStructs(schema, options, Symbol("json")), 1)) val optimized2 = Optimizer.execute(query2.analyze) val prunedSchema2 = StructType.fromDDL("b int") val expected2 = testRelation2 - .select(GetStructField(JsonToStructs(prunedSchema2, options, 'json), 0)).analyze + .select(GetStructField(JsonToStructs(prunedSchema2, options, Symbol("json")), 0)).analyze comparePlans(optimized2, expected2) } @@ -182,13 +185,13 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val query1 = testRelation2 .select(GetArrayStructFields( - JsonToStructs(schema1, options, 'json), field1, 0, 2, true).as("a")) + JsonToStructs(schema1, options, Symbol("json")), field1, 0, 2, true).as("a")) val optimized1 = Optimizer.execute(query1.analyze) val prunedSchema1 = ArrayType(StructType.fromDDL("a int"), containsNull = true) val expected1 = testRelation2 .select(GetArrayStructFields( - JsonToStructs(prunedSchema1, options, 'json), field1, 0, 1, true).as("a")).analyze + JsonToStructs(prunedSchema1, options, Symbol("json")), field1, 0, 1, true).as("a")).analyze comparePlans(optimized1, expected1) val schema2 = ArrayType( @@ -198,14 +201,14 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val field2 = schema2.elementType.asInstanceOf[StructType](1) val query2 = testRelation2 .select(GetArrayStructFields( - JsonToStructs(schema2, options, 'json), field2, 1, 2, false).as("b")) + JsonToStructs(schema2, options, Symbol("json")), field2, 1, 2, false).as("b")) val optimized2 = Optimizer.execute(query2.analyze) val prunedSchema2 = ArrayType( StructType(StructField("b", IntegerType, false) :: Nil), containsNull = false) val expected2 = testRelation2 .select(GetArrayStructFields( - JsonToStructs(prunedSchema2, options, 'json), field2, 0, 1, false).as("b")).analyze + JsonToStructs(prunedSchema2, options, Symbol("json")), field2, 0, 1, false).as("b")).analyze comparePlans(optimized2, expected2) } @@ -213,7 +216,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val options = Map("mode" -> "failfast") val query1 = testRelation2 - .select(GetStructField(JsonToStructs(schema, options, 'json), 0)) + .select(GetStructField(JsonToStructs(schema, options, Symbol("json")), 0)) val optimized1 = Optimizer.execute(query1.analyze) comparePlans(optimized1, query1.analyze) @@ -223,7 +226,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val query2 = testRelation2 .select(GetArrayStructFields( - JsonToStructs(schema1, options, 'json), field1, 0, 2, true).as("a")) + JsonToStructs(schema1, options, Symbol("json")), field1, 0, 2, true).as("a")) val optimized2 = Optimizer.execute(query2.analyze) comparePlans(optimized2, query2.analyze) @@ -237,7 +240,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val nullStruct = 
namedStruct("a", Literal(null, IntegerType), "b", Literal(null, IntegerType)) val UTC_OPT = Option("UTC") - val json: BoundReference = 'json.string.canBeNull.at(0) + val json: BoundReference = Symbol("json").string.canBeNull.at(0) assertEquivalent( testRelation2, @@ -301,7 +304,8 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { val options = Map.empty[String, String] val query = testRelation - .select(JsonToStructs(schema, options, StructsToJson(options, 'struct)).as("struct")) + .select(JsonToStructs( + schema, options, StructsToJson(options, Symbol("struct"))).as("struct")) val optimized = Optimizer.execute(query.analyze) comparePlans(optimized, query.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeLimitZeroSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeLimitZeroSuite.scala index cf875efc62c9..66581a7c5c30 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeLimitZeroSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeLimitZeroSuite.scala @@ -36,14 +36,14 @@ class OptimizeLimitZeroSuite extends PlanTest { PropagateEmptyRelation) :: Nil } - val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) - val testRelation2 = LocalRelation.fromExternalRows(Seq('b.int), data = Seq(Row(1))) + val testRelation1 = LocalRelation.fromExternalRows(Seq(Symbol("a").int), data = Seq(Row(1))) + val testRelation2 = LocalRelation.fromExternalRows(Seq(Symbol("b").int), data = Seq(Row(1))) test("Limit 0: return empty local relation") { val query = testRelation1.limit(0) val optimized = Optimize.execute(query.analyze) - val correctAnswer = LocalRelation('a.int) + val correctAnswer = LocalRelation(Symbol("a").int) comparePlans(optimized, correctAnswer) } @@ -52,7 +52,7 @@ class OptimizeLimitZeroSuite extends PlanTest { val query = LocalLimit(0, testRelation1) val optimized = Optimize.execute(query.analyze) - val correctAnswer = LocalRelation('a.int) + val correctAnswer = LocalRelation(Symbol("a").int) comparePlans(optimized, correctAnswer) } @@ -61,20 +61,23 @@ class OptimizeLimitZeroSuite extends PlanTest { val query = GlobalLimit(0, testRelation1) val optimized = Optimize.execute(query.analyze) - val correctAnswer = LocalRelation('a.int) + val correctAnswer = LocalRelation(Symbol("a").int) comparePlans(optimized, correctAnswer) } Seq( - (Inner, LocalRelation('a.int, 'b.int)), - (LeftOuter, Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze), - (RightOuter, LocalRelation('a.int, 'b.int)), - (FullOuter, Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze) + (Inner, LocalRelation(Symbol("a").int, Symbol("b").int)), + (LeftOuter, Project(Seq(Symbol("a"), + Literal(null).cast(IntegerType).as(Symbol("b"))), testRelation1).analyze), + (RightOuter, LocalRelation(Symbol("a").int, Symbol("b").int)), + (FullOuter, Project(Seq(Symbol("a"), + Literal(null).cast(IntegerType).as(Symbol("b"))), testRelation1).analyze) ).foreach { case (jt, correctAnswer) => test(s"Limit 0: for join type $jt") { val query = testRelation1 - .join(testRelation2.limit(0), joinType = jt, condition = Some('a.attr == 'b.attr)) + .join(testRelation2.limit(0), joinType = jt, + condition = Some(Symbol("a").attr == Symbol("b").attr)) val optimized = Optimize.execute(query.analyze) @@ -83,15 +86,17 @@ class OptimizeLimitZeroSuite extends PlanTest { } test("Limit 0: 3-way join") 
{ - val testRelation3 = LocalRelation.fromExternalRows(Seq('c.int), data = Seq(Row(1))) + val testRelation3 = LocalRelation.fromExternalRows(Seq(Symbol("c").int), data = Seq(Row(1))) val subJoinQuery = testRelation1 - .join(testRelation2, joinType = Inner, condition = Some('a.attr == 'b.attr)) + .join(testRelation2, joinType = Inner, + condition = Some(Symbol("a").attr == Symbol("b").attr)) val query = subJoinQuery - .join(testRelation3.limit(0), joinType = Inner, condition = Some('a.attr == 'c.attr)) + .join(testRelation3.limit(0), joinType = Inner, + condition = Some(Symbol("a").attr == Symbol("c").attr)) val optimized = Optimize.execute(query.analyze) - val correctAnswer = LocalRelation('a.int, 'b.int, 'c.int) + val correctAnswer = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) comparePlans(optimized, correctAnswer) } @@ -101,7 +106,7 @@ class OptimizeLimitZeroSuite extends PlanTest { .intersect(testRelation1.limit(0), isAll = false) val optimized = Optimize.execute(query.analyze) - val correctAnswer = Distinct(LocalRelation('a.int)) + val correctAnswer = Distinct(LocalRelation(Symbol("a").int)) comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala index cf850bbe21ce..bd6990622f63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWindowFunctionsSuite.scala @@ -31,7 +31,7 @@ class OptimizeWindowFunctionsSuite extends PlanTest { OptimizeWindowFunctions) :: Nil } - val testRelation = LocalRelation('a.double, 'b.double, 'c.string) + val testRelation = LocalRelation(Symbol("a").double, Symbol("b").double, Symbol("c").string) val a = testRelation.output(0) val b = testRelation.output(1) val c = testRelation.output(2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala index b093b39cc4b8..416489235556 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala @@ -32,21 +32,22 @@ class OptimizeWithFieldsSuite extends PlanTest { OptimizeUpdateFields, SimplifyExtractValueOps) :: Nil } - private val testRelation = LocalRelation('a.struct('a1.int)) - private val testRelation2 = LocalRelation('a.struct('a1.int).notNull) + private val testRelation = LocalRelation(Symbol("a").struct(Symbol("a1").int)) + private val testRelation2 = LocalRelation(Symbol("a").struct(Symbol("a1").int).notNull) test("combines two adjacent UpdateFields Expressions") { val originalQuery = testRelation .select(Alias( UpdateFields( UpdateFields( - 'a, + Symbol("a"), WithField("b1", Literal(4)) :: Nil), WithField("c1", Literal(5)) :: Nil), "out")()) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) :: + .select(Alias( + UpdateFields(Symbol("a"), WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) :: Nil), "out")()) .analyze @@ -59,14 +60,15 @@ class OptimizeWithFieldsSuite extends PlanTest { UpdateFields( 
UpdateFields( UpdateFields( - 'a, + Symbol("a"), WithField("b1", Literal(4)) :: Nil), WithField("c1", Literal(5)) :: Nil), WithField("d1", Literal(6)) :: Nil), "out")()) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) :: + .select(Alias(UpdateFields( + Symbol("a"), WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) :: WithField("d1", Literal(6)) :: Nil), "out")()) .analyze @@ -76,7 +78,7 @@ class OptimizeWithFieldsSuite extends PlanTest { test("SPARK-32941: optimize WithFields followed by GetStructField") { val originalQuery = testRelation2 .select(Alias( - GetStructField(UpdateFields('a, + GetStructField(UpdateFields(Symbol("a"), WithField("b1", Literal(4)) :: Nil), 1), "out")()) val optimized = Optimize.execute(originalQuery.analyze) @@ -90,16 +92,16 @@ class OptimizeWithFieldsSuite extends PlanTest { test("SPARK-32941: optimize WithFields chain - case insensitive") { val originalQuery = testRelation .select( - Alias(UpdateFields('a, + Alias(UpdateFields(Symbol("a"), WithField("b1", Literal(4)) :: WithField("b1", Literal(5)) :: Nil), "out1")(), - Alias(UpdateFields('a, + Alias(UpdateFields(Symbol("a"), WithField("b1", Literal(4)) :: WithField("B1", Literal(5)) :: Nil), "out2")()) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select( - Alias(UpdateFields('a, WithField("b1", Literal(5)) :: Nil), "out1")(), - Alias(UpdateFields('a, WithField("B1", Literal(5)) :: Nil), "out2")()) + Alias(UpdateFields(Symbol("a"), WithField("b1", Literal(5)) :: Nil), "out1")(), + Alias(UpdateFields(Symbol("a"), WithField("B1", Literal(5)) :: Nil), "out2")()) .analyze comparePlans(optimized, correctAnswer) @@ -109,17 +111,17 @@ class OptimizeWithFieldsSuite extends PlanTest { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { val originalQuery = testRelation .select( - Alias(UpdateFields('a, + Alias(UpdateFields(Symbol("a"), WithField("b1", Literal(4)) :: WithField("b1", Literal(5)) :: Nil), "out1")(), - Alias(UpdateFields('a, + Alias(UpdateFields(Symbol("a"), WithField("b1", Literal(4)) :: WithField("B1", Literal(5)) :: Nil), "out2")()) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation .select( - Alias(UpdateFields('a, WithField("b1", Literal(5)) :: Nil), "out1")(), + Alias(UpdateFields(Symbol("a"), WithField("b1", Literal(5)) :: Nil), "out1")(), Alias( - UpdateFields('a, + UpdateFields(Symbol("a"), WithField("b1", Literal(4)) :: WithField("B1", Literal(5)) :: Nil), "out2")()) .analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala index 1187950c0424..7089bad9bf30 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerLoggingSuite.scala @@ -39,9 +39,10 @@ class OptimizerLoggingSuite extends PlanTest { private def verifyLog(expectedLevel: Level, expectedRulesOrBatches: Seq[String]): Unit = { val logAppender = new LogAppender("optimizer rules") withLogAppender(logAppender, level = Some(Level.TRACE)) { - val input = LocalRelation('a.int, 'b.string, 'c.double) - val query = input.select('a, 'b).select('a).where('a > 1).analyze - val expected = input.where('a > 1).select('a).analyze + val input = 
LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").double) + val query = + input.select(Symbol("a"), Symbol("b")).select(Symbol("a")).where(Symbol("a") > 1).analyze + val expected = input.where(Symbol("a") > 1).select(Symbol("a")).analyze comparePlans(Optimize.execute(query), expected) } val events = logAppender.loggingEvents.filter { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala index a277a2d339e9..f8917ab1f9c5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_EXCLUDED_RULES class OptimizerRuleExclusionSuite extends PlanTest { - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) private def verifyExcludedRules(optimizer: Optimizer, rulesToExclude: Seq[String]): Unit = { val nonExcludableRules = optimizer.nonExcludableRules @@ -121,9 +121,9 @@ class OptimizerRuleExclusionSuite extends PlanTest { PropagateEmptyRelation.ruleName, CombineUnions.ruleName) - val testRelation1 = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation3 = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation1 = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val testRelation2 = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val testRelation3 = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) withSQLConf( OPTIMIZER_EXCLUDED_RULES.key -> excludedRules.foldLeft("")((l, r) => l + "," + r)) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala index f4a52180373c..72021bceb7a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerStructuralIntegrityCheckerSuite.scala @@ -62,8 +62,8 @@ class OptimizerStructuralIntegrityCheckerSuite extends PlanTest { test("check for invalid plan after execution of rule - special expression in wrong operator") { val analyzed = - Aggregate(Nil, Seq[NamedExpression](max('id) as 'm), - LocalRelation('id.long)).analyze + Aggregate(Nil, Seq[NamedExpression](max(Symbol("id")) as Symbol("m")), + LocalRelation(Symbol("id").long)).analyze assert(analyzed.resolved) // Should fail verification with the OptimizeRuleBreakSI rule @@ -80,8 +80,8 @@ class OptimizerStructuralIntegrityCheckerSuite extends PlanTest { test("check for invalid plan before execution of any rule") { val analyzed = - Aggregate(Nil, Seq[NamedExpression](max('id) as 'm), - LocalRelation('id.long)).analyze + Aggregate(Nil, Seq[NamedExpression](max(Symbol("id")) as Symbol("m")), + LocalRelation(Symbol("id").long)).analyze val invalidPlan = OptimizeRuleBreakSI.apply(analyzed) // Should fail verification right at the beginning diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala index 893c111c2906..8ff03faa82a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OuterJoinEliminationSuite.scala @@ -36,20 +36,20 @@ class OuterJoinEliminationSuite extends PlanTest { PushPredicateThroughJoin) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation1 = LocalRelation('d.int, 'e.int, 'f.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val testRelation1 = LocalRelation(Symbol("d").int, Symbol("e").int, Symbol("f").int) test("joins: full outer to inner") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) val originalQuery = x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) .where("x.b".attr >= 1 && "y.d".attr >= 2) val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('b >= 1) - val right = testRelation1.where('d >= 2) + val left = testRelation.where(Symbol("b") >= 1) + val right = testRelation1.where(Symbol("d") >= 2) val correctAnswer = left.join(right, Inner, Option("a".attr === "d".attr)).analyze @@ -57,15 +57,15 @@ class OuterJoinEliminationSuite extends PlanTest { } test("joins: full outer to right") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) val originalQuery = x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)).where("y.d".attr > 2) val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation - val right = testRelation1.where('d > 2) + val right = testRelation1.where(Symbol("d") > 2) val correctAnswer = left.join(right, RightOuter, Option("a".attr === "d".attr)).analyze @@ -73,14 +73,14 @@ class OuterJoinEliminationSuite extends PlanTest { } test("joins: full outer to left") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) val originalQuery = x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)).where("x.a".attr <=> 2) val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('a <=> 2) + val left = testRelation.where(Symbol("a") <=> 2) val right = testRelation1 val correctAnswer = left.join(right, LeftOuter, Option("a".attr === "d".attr)).analyze @@ -89,14 +89,14 @@ class OuterJoinEliminationSuite extends PlanTest { } test("joins: right to inner") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) val originalQuery = x.join(y, RightOuter, Option("x.a".attr === "y.d".attr)).where("x.b".attr > 2) val optimized = Optimize.execute(originalQuery.analyze) - val left = testRelation.where('b > 2) + val left = testRelation.where(Symbol("b") > 2) val right = testRelation1 val correctAnswer = left.join(right, Inner, Option("a".attr === "d".attr)).analyze @@ -105,8 +105,8 @@ class OuterJoinEliminationSuite extends PlanTest { } test("joins: left to inner") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) 
val originalQuery = x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr)) @@ -114,7 +114,7 @@ class OuterJoinEliminationSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation - val right = testRelation1.where('e.isNotNull) + val right = testRelation1.where(Symbol("e").isNotNull) val correctAnswer = left.join(right, Inner, Option("a".attr === "d".attr)).analyze @@ -123,16 +123,17 @@ class OuterJoinEliminationSuite extends PlanTest { // evaluating if mixed OR and NOT expressions can eliminate all null-supplying rows test("joins: left to inner with complicated filter predicates #1") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) val originalQuery = x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr)) - .where(!'e.isNull || ('d.isNotNull && 'f.isNull)) + .where(!Symbol("e").isNull || (Symbol("d").isNotNull && Symbol("f").isNull)) val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation - val right = testRelation1.where(!'e.isNull || ('d.isNotNull && 'f.isNull)) + val right = + testRelation1.where(!Symbol("e").isNull || (Symbol("d").isNotNull && Symbol("f").isNull)) val correctAnswer = left.join(right, Inner, Option("a".attr === "d".attr)).analyze @@ -141,16 +142,16 @@ class OuterJoinEliminationSuite extends PlanTest { // eval(emptyRow) of 'e.in(1, 2) will return null instead of false test("joins: left to inner with complicated filter predicates #2") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) val originalQuery = x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr)) - .where('e.in(1, 2)) + .where(Symbol("e").in(1, 2)) val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation - val right = testRelation1.where('e.in(1, 2)) + val right = testRelation1.where(Symbol("e").in(1, 2)) val correctAnswer = left.join(right, Inner, Option("a".attr === "d".attr)).analyze @@ -159,16 +160,18 @@ class OuterJoinEliminationSuite extends PlanTest { // evaluating if mixed OR and AND expressions can eliminate all null-supplying rows test("joins: left to inner with complicated filter predicates #3") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) val originalQuery = x.join(y, LeftOuter, Option("x.a".attr === "y.d".attr)) - .where((!'e.isNull || ('d.isNotNull && 'f.isNull)) && 'e.isNull) + .where((!Symbol("e").isNull || (Symbol("d").isNotNull && + Symbol("f").isNull)) && Symbol("e").isNull) val optimized = Optimize.execute(originalQuery.analyze) val left = testRelation - val right = testRelation1.where((!'e.isNull || ('d.isNotNull && 'f.isNull)) && 'e.isNull) + val right = testRelation1.where( + (!Symbol("e").isNull || (Symbol("d").isNotNull && Symbol("f").isNull)) && Symbol("e").isNull) val correctAnswer = left.join(right, Inner, Option("a".attr === "d".attr)).analyze @@ -179,8 +182,8 @@ class OuterJoinEliminationSuite extends PlanTest { // can eliminate all null-supplying rows // FULL OUTER => INNER test("joins: left to inner with complicated filter predicates #4") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) val originalQuery = x.join(y, 
FullOuter, Option("x.a".attr === "y.d".attr)) @@ -196,8 +199,8 @@ class OuterJoinEliminationSuite extends PlanTest { } test("joins: no outer join elimination if the filter is not NULL eliminated") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) val originalQuery = x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) @@ -215,8 +218,8 @@ class OuterJoinEliminationSuite extends PlanTest { } test("joins: no outer join elimination if the filter's constraints are not NULL eliminated") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) val originalQuery = x.join(y, FullOuter, Option("x.a".attr === "y.d".attr)) @@ -235,8 +238,8 @@ class OuterJoinEliminationSuite extends PlanTest { test("no outer join elimination if constraint propagation is disabled") { withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { - val x = testRelation.subquery('x) - val y = testRelation1.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation1.subquery(Symbol("y")) // The predicate "x.b + y.d >= 3" will be inferred constraints like: // "x.b != null" and "y.d != null", if constraint propagation is enabled. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 5c980abdd8f5..d64ca05c4ae9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -53,11 +53,11 @@ class PropagateEmptyRelationSuite extends PlanTest { CollapseProject) :: Nil } - val testRelation1 = LocalRelation.fromExternalRows(Seq('a.int), data = Seq(Row(1))) - val testRelation2 = LocalRelation.fromExternalRows(Seq('b.int), data = Seq(Row(1))) + val testRelation1 = LocalRelation.fromExternalRows(Seq(Symbol("a").int), data = Seq(Row(1))) + val testRelation2 = LocalRelation.fromExternalRows(Seq(Symbol("b").int), data = Seq(Row(1))) val metadata = new MetadataBuilder().putLong("test", 1).build() - val testRelation3 = - LocalRelation.fromExternalRows(Seq('c.int.notNull.withMetadata(metadata)), data = Seq(Row(1))) + val testRelation3 = LocalRelation.fromExternalRows( + Seq(Symbol("c").int.notNull.withMetadata(metadata)), data = Seq(Row(1))) test("propagate empty relation through Union") { val query = testRelation1 @@ -65,7 +65,7 @@ class PropagateEmptyRelationSuite extends PlanTest { .union(testRelation2.where(false)) val optimized = Optimize.execute(query.analyze) - val correctAnswer = LocalRelation('a.int) + val correctAnswer = LocalRelation(Symbol("a").int) comparePlans(optimized, correctAnswer) } @@ -78,7 +78,7 @@ class PropagateEmptyRelationSuite extends PlanTest { val query2 = testRelation1.where(false).union(testRelation2) val optimized2 = Optimize.execute(query2.analyze) - val correctAnswer2 = testRelation2.select('b.as('a)).analyze + val correctAnswer2 = testRelation2.select(Symbol("b").as(Symbol("a"))).analyze comparePlans(optimized2, correctAnswer2) val query3 = testRelation1.union(testRelation2.where(false)).union(testRelation3) @@ -88,7 +88,8 @@ class PropagateEmptyRelationSuite extends PlanTest { val query4 = 
testRelation1.where(false).union(testRelation2).union(testRelation3) val optimized4 = Optimize.execute(query4.analyze) - val correctAnswer4 = testRelation2.union(testRelation3).select('b.as('a)).analyze + val correctAnswer4 = + testRelation2.union(testRelation3).select(Symbol("b").as(Symbol("a"))).analyze comparePlans(optimized4, correctAnswer4) // Nullability can change from nullable to non-nullable @@ -115,39 +116,40 @@ class PropagateEmptyRelationSuite extends PlanTest { (true, true, LeftAnti, None), (true, true, LeftSemi, None), - (true, false, Inner, Some(LocalRelation('a.int, 'b.int))), - (true, false, Cross, Some(LocalRelation('a.int, 'b.int))), - (true, false, LeftOuter, - Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)), - (true, false, RightOuter, Some(LocalRelation('a.int, 'b.int))), - (true, false, FullOuter, - Some(Project(Seq('a, Literal(null).cast(IntegerType).as('b)), testRelation1).analyze)), + (true, false, Inner, Some(LocalRelation(Symbol("a").int, Symbol("b").int))), + (true, false, Cross, Some(LocalRelation(Symbol("a").int, Symbol("b").int))), + (true, false, LeftOuter, Some(Project(Seq(Symbol("a"), + Literal(null).cast(IntegerType).as(Symbol("b"))), testRelation1).analyze)), + (true, false, RightOuter, Some(LocalRelation(Symbol("a").int, Symbol("b").int))), + (true, false, FullOuter, Some(Project(Seq(Symbol("a"), + Literal(null).cast(IntegerType).as(Symbol("b"))), testRelation1).analyze)), (true, false, LeftAnti, Some(testRelation1)), - (true, false, LeftSemi, Some(LocalRelation('a.int))), - - (false, true, Inner, Some(LocalRelation('a.int, 'b.int))), - (false, true, Cross, Some(LocalRelation('a.int, 'b.int))), - (false, true, LeftOuter, Some(LocalRelation('a.int, 'b.int))), - (false, true, RightOuter, - Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)), - (false, true, FullOuter, - Some(Project(Seq(Literal(null).cast(IntegerType).as('a), 'b), testRelation2).analyze)), - (false, true, LeftAnti, Some(LocalRelation('a.int))), - (false, true, LeftSemi, Some(LocalRelation('a.int))), - - (false, false, Inner, Some(LocalRelation('a.int, 'b.int))), - (false, false, Cross, Some(LocalRelation('a.int, 'b.int))), - (false, false, LeftOuter, Some(LocalRelation('a.int, 'b.int))), - (false, false, RightOuter, Some(LocalRelation('a.int, 'b.int))), - (false, false, FullOuter, Some(LocalRelation('a.int, 'b.int))), - (false, false, LeftAnti, Some(LocalRelation('a.int))), - (false, false, LeftSemi, Some(LocalRelation('a.int))) + (true, false, LeftSemi, Some(LocalRelation(Symbol("a").int))), + + (false, true, Inner, Some(LocalRelation(Symbol("a").int, Symbol("b").int))), + (false, true, Cross, Some(LocalRelation(Symbol("a").int, Symbol("b").int))), + (false, true, LeftOuter, Some(LocalRelation(Symbol("a").int, Symbol("b").int))), + (false, true, RightOuter, Some(Project(Seq(Literal(null).cast(IntegerType).as(Symbol("a")), + Symbol("b")), testRelation2).analyze)), + (false, true, FullOuter, Some(Project(Seq(Literal(null).cast(IntegerType).as(Symbol("a")), + Symbol("b")), testRelation2).analyze)), + (false, true, LeftAnti, Some(LocalRelation(Symbol("a").int))), + (false, true, LeftSemi, Some(LocalRelation(Symbol("a").int))), + + (false, false, Inner, Some(LocalRelation(Symbol("a").int, Symbol("b").int))), + (false, false, Cross, Some(LocalRelation(Symbol("a").int, Symbol("b").int))), + (false, false, LeftOuter, Some(LocalRelation(Symbol("a").int, Symbol("b").int))), + (false, false, RightOuter, 
Some(LocalRelation(Symbol("a").int, Symbol("b").int))), + (false, false, FullOuter, Some(LocalRelation(Symbol("a").int, Symbol("b").int))), + (false, false, LeftAnti, Some(LocalRelation(Symbol("a").int))), + (false, false, LeftSemi, Some(LocalRelation(Symbol("a").int))) ) testcases.foreach { case (left, right, jt, answer) => val query = testRelation1 .where(left) - .join(testRelation2.where(right), joinType = jt, condition = Some('a.attr == 'b.attr)) + .join(testRelation2.where(right), joinType = jt, + condition = Some(Symbol("a").attr == Symbol("b").attr)) val optimized = Optimize.execute(query.analyze) val correctAnswer = answer.getOrElse(OptimizeWithoutPropagateEmptyRelation.execute(query.analyze)) @@ -158,19 +160,19 @@ class PropagateEmptyRelationSuite extends PlanTest { test("propagate empty relation through UnaryNode") { val query = testRelation1 .where(false) - .select('a) - .groupBy('a)('a) - .where('a > 1) - .orderBy('a.asc) + .select(Symbol("a")) + .groupBy(Symbol("a"))(Symbol("a")) + .where(Symbol("a") > 1) + .orderBy(Symbol("a").asc) val optimized = Optimize.execute(query.analyze) - val correctAnswer = LocalRelation('a.int) + val correctAnswer = LocalRelation(Symbol("a").int) comparePlans(optimized, correctAnswer) } test("propagate empty streaming relation through multiple UnaryNode") { - val output = Seq('a.int) + val output = Seq(Symbol("a").int) val data = Seq(Row(1)) val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) @@ -181,10 +183,10 @@ class PropagateEmptyRelationSuite extends PlanTest { val query = relation .where(false) - .select('a) - .where('a > 1) - .where('a =!= 200) - .orderBy('a.asc) + .select(Symbol("a")) + .where(Symbol("a") > 1) + .where(Symbol("a") =!= 200) + .orderBy(Symbol("a").asc) val optimized = Optimize.execute(query.analyze) val correctAnswer = LocalRelation(output, isStreaming = true) @@ -193,7 +195,7 @@ class PropagateEmptyRelationSuite extends PlanTest { } test("don't propagate empty streaming relation through agg") { - val output = Seq('a.int) + val output = Seq(Symbol("a").int) val data = Seq(Row(1)) val schema = StructType.fromAttributes(output) val converter = CatalystTypeConverters.createToCatalystConverter(schema) @@ -203,7 +205,7 @@ class PropagateEmptyRelationSuite extends PlanTest { isStreaming = true) val query = relation - .groupBy('a)('a) + .groupBy(Symbol("a"))(Symbol("a")) val optimized = Optimize.execute(query.analyze) val correctAnswer = query.analyze @@ -214,17 +216,17 @@ class PropagateEmptyRelationSuite extends PlanTest { test("don't propagate non-empty local relation") { val query = testRelation1 .where(true) - .groupBy('a)('a) - .where('a > 1) - .orderBy('a.asc) - .select('a) + .groupBy(Symbol("a"))(Symbol("a")) + .where(Symbol("a") > 1) + .orderBy(Symbol("a").asc) + .select(Symbol("a")) val optimized = Optimize.execute(query.analyze) val correctAnswer = testRelation1 - .where('a > 1) - .groupBy('a)('a) - .orderBy('a.asc) - .select('a) + .where(Symbol("a") > 1) + .groupBy(Symbol("a"))(Symbol("a")) + .orderBy(Symbol("a").asc) + .select(Symbol("a")) comparePlans(optimized, correctAnswer.analyze) } @@ -232,10 +234,10 @@ class PropagateEmptyRelationSuite extends PlanTest { test("propagate empty relation through Aggregate with grouping expressions") { val query = testRelation1 .where(false) - .groupBy('a)('a, ('a + 1).as('x)) + .groupBy(Symbol("a"))(Symbol("a"), (Symbol("a") + 1).as(Symbol("x"))) val optimized = Optimize.execute(query.analyze) - val correctAnswer = 
LocalRelation('a.int, 'x.int).analyze + val correctAnswer = LocalRelation(Symbol("a").int, Symbol("x").int).analyze comparePlans(optimized, correctAnswer) } @@ -246,14 +248,14 @@ class PropagateEmptyRelationSuite extends PlanTest { .groupBy()() val optimized = Optimize.execute(query.analyze) - val correctAnswer = LocalRelation('a.int).groupBy()().analyze + val correctAnswer = LocalRelation(Symbol("a").int).groupBy()().analyze comparePlans(optimized, correctAnswer) } test("propagate empty relation keeps the plan resolved") { val query = testRelation1.join( - LocalRelation('a.int, 'b.int), UsingJoin(FullOuter, "a" :: Nil), None) + LocalRelation(Symbol("a").int, Symbol("b").int), UsingJoin(FullOuter, "a" :: Nil), None) val optimized = Optimize.execute(query.analyze) assert(optimized.resolved) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index 15fbe3c5b0a1..0929f351db92 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -39,11 +39,11 @@ class PruneFiltersSuite extends PlanTest { PushPredicateThroughJoin) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) test("Constraints of isNull + LeftOuter") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val query = x.where("x.b".attr.isNull).join(y, LeftOuter) val queryWithUselessFilter = query.where("x.b".attr.isNull) @@ -55,15 +55,15 @@ class PruneFiltersSuite extends PlanTest { } test("Constraints of unionall") { - val tr1 = LocalRelation('a.int, 'b.int, 'c.int) - val tr2 = LocalRelation('d.int, 'e.int, 'f.int) - val tr3 = LocalRelation('g.int, 'h.int, 'i.int) + val tr1 = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val tr2 = LocalRelation(Symbol("d").int, Symbol("e").int, Symbol("f").int) + val tr3 = LocalRelation(Symbol("g").int, Symbol("h").int, Symbol("i").int) val query = - tr1.where('a.attr > 10) - .union(tr2.where('d.attr > 10) - .union(tr3.where('g.attr > 10))) - val queryWithUselessFilter = query.where('a.attr > 10) + tr1.where(Symbol("a").attr > 10) + .union(tr2.where(Symbol("d").attr > 10) + .union(tr3.where(Symbol("g").attr > 10))) + val queryWithUselessFilter = query.where(Symbol("a").attr > 10) val optimized = Optimize.execute(queryWithUselessFilter.analyze) val correctAnswer = query.analyze @@ -72,17 +72,19 @@ class PruneFiltersSuite extends PlanTest { } test("Pruning multiple constraints in the same run") { - val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) - val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + val tr1 = + LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int).subquery(Symbol("tr1")) + val tr2 = + LocalRelation(Symbol("a").int, Symbol("d").int, Symbol("e").int).subquery(Symbol("tr2")) val query = tr1 .where("tr1.a".attr > 10 || "tr1.c".attr < 10) - .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) + .join(tr2.where(Symbol("d").attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) // different order of "tr2.a" and "tr1.a" val queryWithUselessFilter = query.where( ("tr1.a".attr > 10 || "tr1.c".attr < 10) && - 
'd.attr < 100 && + Symbol("d").attr < 100 && "tr2.a".attr === "tr1.a".attr) val optimized = Optimize.execute(queryWithUselessFilter.analyze) @@ -92,21 +94,23 @@ class PruneFiltersSuite extends PlanTest { } test("Partial pruning") { - val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) - val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + val tr1 = + LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int).subquery(Symbol("tr1")) + val tr2 = + LocalRelation(Symbol("a").int, Symbol("d").int, Symbol("e").int).subquery(Symbol("tr2")) // One of the filter condition does not exist in the constraints of its child // Thus, the filter is not removed val query = tr1 .where("tr1.a".attr > 10) - .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.d".attr)) + .join(tr2.where(Symbol("d").attr < 100), Inner, Some("tr1.a".attr === "tr2.d".attr)) val queryWithExtraFilters = - query.where("tr1.a".attr > 10 && 'd.attr < 100 && "tr1.a".attr === "tr2.a".attr) + query.where("tr1.a".attr > 10 && Symbol("d").attr < 100 && "tr1.a".attr === "tr2.a".attr) val optimized = Optimize.execute(queryWithExtraFilters.analyze) val correctAnswer = tr1 .where("tr1.a".attr > 10) - .join(tr2.where('d.attr < 100), + .join(tr2.where(Symbol("d").attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr && "tr1.a".attr === "tr2.d".attr)).analyze @@ -114,8 +118,8 @@ class PruneFiltersSuite extends PlanTest { } test("No predicate is pruned") { - val x = testRelation.subquery('x) - val y = testRelation.subquery('y) + val x = testRelation.subquery(Symbol("x")) + val y = testRelation.subquery(Symbol("y")) val query = x.where("x.b".attr.isNull).join(y, LeftOuter) val queryWithExtraFilters = query.where("x.b".attr.isNotNull) @@ -129,24 +133,28 @@ class PruneFiltersSuite extends PlanTest { } test("Nondeterministic predicate is not pruned") { - val originalQuery = testRelation.where(Rand(10) > 5).select('a).where(Rand(10) > 5).analyze + val originalQuery = + testRelation.where(Rand(10) > 5).select(Symbol("a")).where(Rand(10) > 5).analyze val optimized = Optimize.execute(originalQuery) - val correctAnswer = testRelation.where(Rand(10) > 5).where(Rand(10) > 5).select('a).analyze + val correctAnswer = + testRelation.where(Rand(10) > 5).where(Rand(10) > 5).select(Symbol("a")).analyze comparePlans(optimized, correctAnswer) } test("No pruning when constraint propagation is disabled") { - val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) - val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + val tr1 = + LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int).subquery(Symbol("tr1")) + val tr2 = + LocalRelation(Symbol("a").int, Symbol("d").int, Symbol("e").int).subquery(Symbol("tr2")) val query = tr1 .where("tr1.a".attr > 10 || "tr1.c".attr < 10) - .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) + .join(tr2.where(Symbol("d").attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) val queryWithUselessFilter = query.where( ("tr1.a".attr > 10 || "tr1.c".attr < 10) && - 'd.attr < 100) + Symbol("d").attr < 100) withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "false") { val optimized = Optimize.execute(queryWithUselessFilter.analyze) @@ -155,7 +163,7 @@ class PruneFiltersSuite extends PlanTest { // and duplicate filters. 
val correctAnswer = tr1 .where("tr1.a".attr > 10 || "tr1.c".attr < 10).where("tr1.a".attr > 10 || "tr1.c".attr < 10) - .join(tr2.where('d.attr < 100).where('d.attr < 100), + .join(tr2.where(Symbol("d").attr < 100).where(Symbol("d").attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala index ae9a694b5044..c44ac123f0e4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesSuite.scala @@ -34,18 +34,18 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { PullupCorrelatedPredicates) :: Nil } - val testRelation = LocalRelation('a.int, 'b.double) - val testRelation2 = LocalRelation('c.int, 'd.double) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").double) + val testRelation2 = LocalRelation(Symbol("c").int, Symbol("d").double) test("PullupCorrelatedPredicates should not produce unresolved plan") { val subPlan = testRelation2 - .where('b < 'd) - .select('c) + .where(Symbol("b") < Symbol("d")) + .select(Symbol("c")) val inSubquery = testRelation - .where(InSubquery(Seq('a), ListQuery(subPlan))) - .select('a).analyze + .where(InSubquery(Seq(Symbol("a")), ListQuery(subPlan))) + .select(Symbol("a")).analyze assert(inSubquery.resolved) val optimized = Optimize.execute(inSubquery) @@ -55,12 +55,12 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { test("PullupCorrelatedPredicates in correlated subquery idempotency check") { val subPlan = testRelation2 - .where('b < 'd) - .select('c) + .where(Symbol("b") < Symbol("d")) + .select(Symbol("c")) val inSubquery = testRelation - .where(InSubquery(Seq('a), ListQuery(subPlan))) - .select('a).analyze + .where(InSubquery(Seq(Symbol("a")), ListQuery(subPlan))) + .select(Symbol("a")).analyze assert(inSubquery.resolved) val optimized = Optimize.execute(inSubquery) @@ -71,12 +71,12 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { test("PullupCorrelatedPredicates exists correlated subquery idempotency check") { val subPlan = testRelation2 - .where('b === 'd && 'd === 1) + .where(Symbol("b") === Symbol("d") && Symbol("d") === 1) .select(Literal(1)) val existsSubquery = testRelation .where(Exists(subPlan)) - .select('a).analyze + .select(Symbol("a")).analyze assert(existsSubquery.resolved) val optimized = Optimize.execute(existsSubquery) @@ -87,12 +87,12 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { test("PullupCorrelatedPredicates scalar correlated subquery idempotency check") { val subPlan = testRelation2 - .where('b === 'd && 'd === 1) - .select(max('d)) + .where(Symbol("b") === Symbol("d") && Symbol("d") === 1) + .select(max(Symbol("d"))) val scalarSubquery = testRelation .where(ScalarSubquery(subPlan) === 1) - .select('a).analyze + .select(Symbol("a")).analyze val optimized = Optimize.execute(scalarSubquery) val doubleOptimized = Optimize.execute(optimized) @@ -100,8 +100,8 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { } test("PullupCorrelatedPredicates should handle deletes") { - val subPlan = testRelation2.where('a === 'c).select('c) - val cond = InSubquery(Seq('a), ListQuery(subPlan)) + val subPlan = testRelation2.where(Symbol("a") === 
Symbol("c")).select(Symbol("c")) + val cond = InSubquery(Seq(Symbol("a")), ListQuery(subPlan)) val deletePlan = DeleteFromTable(testRelation, Some(cond)).analyze assert(deletePlan.resolved) @@ -118,8 +118,8 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { } test("PullupCorrelatedPredicates should handle updates") { - val subPlan = testRelation2.where('a === 'c).select('c) - val cond = InSubquery(Seq('a), ListQuery(subPlan)) + val subPlan = testRelation2.where(Symbol("a") === Symbol("c")).select(Symbol("c")) + val cond = InSubquery(Seq(Symbol("a")), ListQuery(subPlan)) val updatePlan = UpdateTable(testRelation, Seq.empty, Some(cond)).analyze assert(updatePlan.resolved) @@ -136,16 +136,17 @@ class PullupCorrelatedPredicatesSuite extends PlanTest { } test("PullupCorrelatedPredicates should handle merge") { - val testRelation3 = LocalRelation('e.int, 'f.double) - val subPlan = testRelation3.where('a === 'e).select('e) - val cond = InSubquery(Seq('a), ListQuery(subPlan)) + val testRelation3 = LocalRelation(Symbol("e").int, Symbol("f").double) + val subPlan = testRelation3.where(Symbol("a") === Symbol("e")).select(Symbol("e")) + val cond = InSubquery(Seq(Symbol("a")), ListQuery(subPlan)) val mergePlan = MergeIntoTable( testRelation, testRelation2, cond, Seq(DeleteAction(None)), - Seq(InsertAction(None, Seq(Assignment('a, 'c), Assignment('b, 'd))))) + Seq(InsertAction(None, Seq(Assignment(Symbol("a"), + Symbol("c")), Assignment(Symbol("b"), Symbol("d")))))) val analyzedMergePlan = mergePlan.analyze assert(analyzedMergePlan.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index cb90a398604f..ba7745ec264b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -38,7 +38,7 @@ class PushFoldableIntoBranchesSuite BooleanSimplification, ConstantFolding, SimplifyConditionals, PushFoldableIntoBranches) :: Nil } - private val relation = LocalRelation('a.int, 'b.int, 'c.boolean) + private val relation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").boolean) private val a = EqualTo(UnresolvedAttribute("a"), Literal(100)) private val b = UnresolvedAttribute("b") private val c = EqualTo(UnresolvedAttribute("c"), Literal(true)) @@ -271,10 +271,10 @@ class PushFoldableIntoBranchesSuite test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") { assertEquivalent( - EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(0)), - 'a > 10 <=> TrueLiteral) + EqualTo(CaseWhen(Seq((Symbol("a") > 10, Literal(0))), Literal(1)), Literal(0)), + Symbol("a") > 10 <=> TrueLiteral) assertEquivalent( - EqualTo(CaseWhen(Seq(('a > 10, Literal(0))), Literal(1)), Literal(1)), - Not('a > 10 <=> TrueLiteral)) + EqualTo(CaseWhen(Seq((Symbol("a") > 10, Literal(0))), Literal(1)), Literal(1)), + Not(Symbol("a") > 10 <=> TrueLiteral)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectThroughUnionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectThroughUnionSuite.scala index 294d29842b04..634a7fb9fca9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectThroughUnionSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushProjectThroughUnionSuite.scala
@@ -33,19 +33,19 @@ class PushProjectThroughUnionSuite extends PlanTest {
  test("SPARK-25450 PushProjectThroughUnion rule uses the same exprId for project expressions " +
    "in each Union child, causing mistakes in constant propagation") {
-    val testRelation1 = LocalRelation('a.string, 'b.int, 'c.string)
-    val testRelation2 = LocalRelation('d.string, 'e.int, 'f.string)
+    val testRelation1 = LocalRelation(Symbol("a").string, Symbol("b").int, Symbol("c").string)
+    val testRelation2 = LocalRelation(Symbol("d").string, Symbol("e").int, Symbol("f").string)
    val query = testRelation1
-      .union(testRelation2.select("bar".as("d"), 'e, 'f))
-      .select('a.as("n"))
-      .select('n, "dummy").analyze
+      .union(testRelation2.select("bar".as("d"), Symbol("e"), Symbol("f")))
+      .select(Symbol("a").as("n"))
+      .select(Symbol("n"), "dummy").analyze
    val optimized = Optimize.execute(query)
    val expected = testRelation1
-      .select('a.as("n"))
-      .select('n, "dummy")
+      .select(Symbol("a").as("n"))
+      .select(Symbol("n"), "dummy")
      .union(testRelation2
-        .select("bar".as("d"), 'e, 'f)
+        .select("bar".as("d"), Symbol("e"), Symbol("f"))
        .select("bar".as("n"))
        .select("bar".as("n"), "dummy")).analyze
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReassignLambdaVariableIDSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReassignLambdaVariableIDSuite.scala
index 06a32c77ac5e..148217b36797 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReassignLambdaVariableIDSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReassignLambdaVariableIDSuite.scala
@@ -32,7 +32,7 @@ class ReassignLambdaVariableIDSuite extends PlanTest {
  }
  test("basic: replace positive IDs with unique negative IDs") {
-    val testRelation = LocalRelation('col.int)
+    val testRelation = LocalRelation(Symbol("col").int)
    val var1 = LambdaVariable("a", BooleanType, true, id = 2)
    val var2 = LambdaVariable("b", BooleanType, true, id = 4)
    val query = testRelation.where(var1 && var2)
@@ -42,7 +42,7 @@
  }
  test("ignore LambdaVariable with negative IDs") {
-    val testRelation = LocalRelation('col.int)
+    val testRelation = LocalRelation(Symbol("col").int)
    val var1 = LambdaVariable("a", BooleanType, true, id = -2)
    val var2 = LambdaVariable("b", BooleanType, true, id = -4)
    val query = testRelation.where(var1 && var2)
@@ -51,7 +51,7 @@
  }
  test("fail if positive ID LambdaVariable and negative LambdaVariable both exist") {
-    val testRelation = LocalRelation('col.int)
+    val testRelation = LocalRelation(Symbol("col").int)
    val var1 = LambdaVariable("a", BooleanType, true, id = -2)
    val var2 = LambdaVariable("b", BooleanType, true, id = 4)
    val query = testRelation.where(var1 && var2)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveNoopOperatorsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveNoopOperatorsSuite.scala
index cedd21d2bf52..614cda3018df 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveNoopOperatorsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveNoopOperatorsSuite.scala
@@ -31,12 +31,12 @@ class RemoveNoopOperatorsSuite extends PlanTest {
        RemoveNoopOperators) :: Nil
  }
-  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+  val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int)
  test("Remove all redundant projections in one iteration") {
    val originalQuery = testRelation
-      .select('a, 'b, 'c)
-      .select('a, 'b, 'c)
+      .select(Symbol("a"), Symbol("b"), Symbol("c"))
+      .select(Symbol("a"), Symbol("b"), Symbol("c"))
      .analyze
    val optimized = Optimize.execute(originalQuery.analyze)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala
index 2e0ab7f64f4d..0757d0723a5e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala
@@ -37,95 +37,100 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper
  }
  test("all expressions in project list are aliased child output") {
-    val relation = LocalRelation('a.int, 'b.int)
-    val query = relation.select('a as 'a, 'b as 'b).analyze
+    val relation = LocalRelation(Symbol("a").int, Symbol("b").int)
+    val query = relation.select(Symbol("a") as Symbol("a"), Symbol("b") as Symbol("b")).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, relation)
  }
  test("all expressions in project list are aliased child output but with different order") {
-    val relation = LocalRelation('a.int, 'b.int)
-    val query = relation.select('b as 'b, 'a as 'a).analyze
+    val relation = LocalRelation(Symbol("a").int, Symbol("b").int)
+    val query = relation.select(Symbol("b") as Symbol("b"), Symbol("a") as Symbol("a")).analyze
    val optimized = Optimize.execute(query)
-    val expected = relation.select('b, 'a).analyze
+    val expected = relation.select(Symbol("b"), Symbol("a")).analyze
    comparePlans(optimized, expected)
  }
  test("some expressions in project list are aliased child output") {
-    val relation = LocalRelation('a.int, 'b.int)
-    val query = relation.select('a as 'a, 'b).analyze
+    val relation = LocalRelation(Symbol("a").int, Symbol("b").int)
+    val query = relation.select(Symbol("a") as Symbol("a"), Symbol("b")).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, relation)
  }
  test("some expressions in project list are aliased child output but with different order") {
-    val relation = LocalRelation('a.int, 'b.int)
-    val query = relation.select('b as 'b, 'a).analyze
+    val relation = LocalRelation(Symbol("a").int, Symbol("b").int)
+    val query = relation.select(Symbol("b") as Symbol("b"), Symbol("a")).analyze
    val optimized = Optimize.execute(query)
-    val expected = relation.select('b, 'a).analyze
+    val expected = relation.select(Symbol("b"), Symbol("a")).analyze
    comparePlans(optimized, expected)
  }
  test("some expressions in project list are not Alias or Attribute") {
-    val relation = LocalRelation('a.int, 'b.int)
-    val query = relation.select('a as 'a, 'b + 1).analyze
+    val relation = LocalRelation(Symbol("a").int, Symbol("b").int)
+    val query = relation.select(Symbol("a") as Symbol("a"), Symbol("b") + 1).analyze
    val optimized = Optimize.execute(query)
-    val expected = relation.select('a, 'b + 1).analyze
+    val expected = relation.select(Symbol("a"), Symbol("b") + 1).analyze
    comparePlans(optimized, expected)
  }
  test("some expressions in project list are aliased child output but with metadata") {
-    val relation = LocalRelation('a.int,
'b.int) + val relation = LocalRelation(Symbol("a").int, Symbol("b").int) val metadata = new MetadataBuilder().putString("x", "y").build() - val aliasWithMeta = Alias('a, "a")(explicitMetadata = Some(metadata)) - val query = relation.select(aliasWithMeta, 'b).analyze + val aliasWithMeta = Alias(Symbol("a"), "a")(explicitMetadata = Some(metadata)) + val query = relation.select(aliasWithMeta, Symbol("b")).analyze val optimized = Optimize.execute(query) comparePlans(optimized, query) } test("retain deduplicating alias in self-join") { - val relation = LocalRelation('a.int) - val fragment = relation.select('a as 'a) - val query = fragment.select('a as 'a).join(fragment.select('a as 'a)).analyze + val relation = LocalRelation(Symbol("a").int) + val fragment = relation.select(Symbol("a") as Symbol("a")) + val query = fragment.select(Symbol("a") as Symbol("a")) + .join(fragment.select(Symbol("a") as Symbol("a"))).analyze val optimized = Optimize.execute(query) - val expected = relation.join(relation.select('a as 'a)).analyze + val expected = relation.join(relation.select(Symbol("a") as Symbol("a"))).analyze comparePlans(optimized, expected) } test("alias removal should not break after push project through union") { - val r1 = LocalRelation('a.int) - val r2 = LocalRelation('b.int) - val query = r1.select('a as 'a).union(r2.select('b as 'b)).select('a).analyze + val r1 = LocalRelation(Symbol("a").int) + val r2 = LocalRelation(Symbol("b").int) + val query = r1.select(Symbol("a") as Symbol("a")) + .union(r2.select(Symbol("b") as Symbol("b"))).select(Symbol("a")).analyze val optimized = Optimize.execute(query) val expected = r1.union(r2) comparePlans(optimized, expected) } test("remove redundant alias from aggregate") { - val relation = LocalRelation('a.int, 'b.int) - val query = relation.groupBy('a as 'a)('a as 'a, sum('b)).analyze + val relation = LocalRelation(Symbol("a").int, Symbol("b").int) + val query = relation + .groupBy(Symbol("a") as Symbol("a"))(Symbol("a") as Symbol("a"), sum(Symbol("b"))).analyze val optimized = Optimize.execute(query) - val expected = relation.groupBy('a)('a, sum('b)).analyze + val expected = relation.groupBy(Symbol("a"))(Symbol("a"), sum(Symbol("b"))).analyze comparePlans(optimized, expected) } test("remove redundant alias from window") { - val relation = LocalRelation('a.int, 'b.int) - val query = relation.window(Seq('b as 'b), Seq('a as 'a), Seq()).analyze + val relation = LocalRelation(Symbol("a").int, Symbol("b").int) + val query = relation + .window(Seq(Symbol("b") as Symbol("b")), Seq(Symbol("a") as Symbol("a")), Seq()).analyze val optimized = Optimize.execute(query) - val expected = relation.window(Seq('b), Seq('a), Seq()).analyze + val expected = relation.window(Seq(Symbol("b")), Seq(Symbol("a")), Seq()).analyze comparePlans(optimized, expected) } test("do not remove output attributes from a subquery") { - val relation = LocalRelation('a.int, 'b.int) - val query = Subquery( - relation.select('a as "a", 'b as "b").where('b < 10).select('a).analyze, + val relation = LocalRelation(Symbol("a").int, Symbol("b").int) + val query = Subquery(relation + .select(Symbol("a") as "a", Symbol("b") as "b") + .where(Symbol("b") < 10).select(Symbol("a")).analyze, correlated = false) val optimized = Optimize.execute(query) - val expected = Subquery( - relation.select('a as "a", 'b).where('b < 10).select('a).analyze, - correlated = false) + val expected = Subquery(relation + .select(Symbol("a") as "a", Symbol("b")) + .where(Symbol("b") < 10).select(Symbol("a")).analyze, correlated 
= false) comparePlans(optimized, expected) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala index a1ab0a834474..20282d78ca0e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala @@ -32,17 +32,17 @@ class ReorderAssociativeOperatorSuite extends PlanTest { ReorderAssociativeOperator) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) test("Reorder associative operators") { val originalQuery = testRelation .select( - (Literal(3) + ((Literal(1) + 'a) + 2)) + 4, - 'b * 1 * 2 * 3 * 4, - ('b + 1) * 2 * 3 * 4, - 'a + 1 + 'b + 2 + 'c + 3, - 'a + 1 + 'b * 2 + 'c + 3, + (Literal(3) + ((Literal(1) + Symbol("a")) + 2)) + 4, + Symbol("b") * 1 * 2 * 3 * 4, + (Symbol("b") + 1) * 2 * 3 * 4, + Symbol("a") + 1 + Symbol("b") + 2 + Symbol("c") + 3, + Symbol("a") + 1 + Symbol("b") * 2 + Symbol("c") + 3, Rand(0) * 1 * 2 * 3 * 4) val optimized = Optimize.execute(originalQuery.analyze) @@ -50,11 +50,11 @@ class ReorderAssociativeOperatorSuite extends PlanTest { val correctAnswer = testRelation .select( - ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"), - ('b * 24).as("((((b * 1) * 2) * 3) * 4)"), - (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"), - ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"), - ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"), + (Symbol("a") + 10).as("((3 + ((1 + a) + 2)) + 4)"), + (Symbol("b") * 24).as("((((b * 1) * 2) * 3) * 4)"), + ((Symbol("b") + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"), + (Symbol("a") + Symbol("b") + Symbol("c") + 6).as("(((((a + 1) + b) + 2) + c) + 3)"), + (Symbol("a") + Symbol("b") * 2 + Symbol("c") + 4).as("((((a + 1) + (b * 2)) + c) + 3)"), Rand(0) * 1 * 2 * 3 * 4) .analyze diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index ffab358721e1..fda457e88067 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -42,8 +42,9 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { } private val testRelation = - LocalRelation('i.int, 'b.boolean, 'a.array(IntegerType), 'm.map(IntegerType, IntegerType)) - private val anotherTestRelation = LocalRelation('d.int) + LocalRelation(Symbol("i").int, Symbol("b").boolean, + Symbol("a").array(IntegerType), Symbol("m").map(IntegerType, IntegerType)) + private val anotherTestRelation = LocalRelation(Symbol("d").int) test("replace null inside filter and join conditions") { testFilter(originalCond = Literal(null, BooleanType), expectedCond = FalseLiteral) @@ -351,33 +352,33 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name)) test("replace nulls in lambda function of ArrayFilter") { - testHigherOrderFunc('a, ArrayFilter, Seq(lv('e))) + testHigherOrderFunc(Symbol("a"), ArrayFilter, Seq(lv(Symbol("e")))) } 
test("replace nulls in lambda function of ArrayExists") { withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key -> "true") { - val lambdaArgs = Seq(lv('e)) + val lambdaArgs = Seq(lv(Symbol("e"))) val cond = GreaterThan(lambdaArgs.last, Literal(0)) val lambda = LambdaFunction( function = If(cond, Literal(null, BooleanType), TrueLiteral), arguments = lambdaArgs) - val expr = ArrayExists('a, lambda) + val expr = ArrayExists(Symbol("a"), lambda) testProjection(originalExpr = expr, expectedExpr = expr) } withSQLConf(SQLConf.LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC.key -> "false") { - testHigherOrderFunc('a, ArrayExists.apply, Seq(lv('e))) + testHigherOrderFunc(Symbol("a"), ArrayExists.apply, Seq(lv(Symbol("e")))) } } test("replace nulls in lambda function of MapFilter") { - testHigherOrderFunc('m, MapFilter, Seq(lv('k), lv('v))) + testHigherOrderFunc(Symbol("m"), MapFilter, Seq(lv(Symbol("k")), lv(Symbol("v")))) } test("inability to replace nulls in arbitrary higher-order function") { val lambdaFunc = LambdaFunction( - function = If(lv('e) > 0, Literal(null, BooleanType), TrueLiteral), - arguments = Seq[NamedExpression](lv('e))) - val column = ArrayTransform('a, lambdaFunc) + function = If(lv(Symbol("e")) > 0, Literal(null, BooleanType), TrueLiteral), + arguments = Seq[NamedExpression](lv(Symbol("e")))) + val column = ArrayTransform(Symbol("a"), lambdaFunc) testProjection(originalExpr = column, expectedExpr = column) } @@ -449,8 +450,8 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { function = !(cond <=> TrueLiteral), arguments = lambdaArgs) testProjection( - originalExpr = createExpr(argument, lambda1) as 'x, - expectedExpr = createExpr(argument, lambda2) as 'x) + originalExpr = createExpr(argument, lambda1) as Symbol("x"), + expectedExpr = createExpr(argument, lambda2) as Symbol("x")) } private def test( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala index 9bf864f5201f..de75a6558ad2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala @@ -40,22 +40,23 @@ class ReplaceOperatorSuite extends PlanTest { } test("replace Intersect with Left-semi Join") { - val table1 = LocalRelation('a.int, 'b.int) - val table2 = LocalRelation('c.int, 'd.int) + val table1 = LocalRelation(Symbol("a").int, Symbol("b").int) + val table2 = LocalRelation(Symbol("c").int, Symbol("d").int) val query = Intersect(table1, table2, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = Aggregate(table1.output, table1.output, - Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd), JoinHint.NONE)).analyze + Join(table1, table2, LeftSemi, Option(Symbol("a") <=> Symbol("c") && + Symbol("b") <=> Symbol("d")), JoinHint.NONE)).analyze comparePlans(optimized, correctAnswer) } test("replace Except with Filter while both the nodes are of type Filter") { - val attributeA = 'a.int - val attributeB = 'b.int + val attributeA = Symbol("a").int + val attributeB = Symbol("b").int val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) val table2 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) @@ -73,8 +74,8 @@ class ReplaceOperatorSuite extends PlanTest { } test("replace Except with Filter while only right 
node is of type Filter") { - val attributeA = 'a.int - val attributeB = 'b.int + val attributeA = Symbol("a").int + val attributeB = Symbol("b").int val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) val table2 = Filter(attributeB < 1, Filter(attributeA >= 2, table1)) @@ -91,8 +92,8 @@ class ReplaceOperatorSuite extends PlanTest { } test("replace Except with Filter while both the nodes are of type Project") { - val attributeA = 'a.int - val attributeB = 'b.int + val attributeA = Symbol("a").int + val attributeB = Symbol("b").int val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) val table2 = Project(Seq(attributeA, attributeB), table1) @@ -111,8 +112,8 @@ class ReplaceOperatorSuite extends PlanTest { } test("replace Except with Filter while only right node is of type Project") { - val attributeA = 'a.int - val attributeB = 'b.int + val attributeA = Symbol("a").int + val attributeB = Symbol("b").int val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) val table2 = Filter(attributeB === 2, Filter(attributeA === 1, table1)) @@ -131,8 +132,8 @@ class ReplaceOperatorSuite extends PlanTest { } test("replace Except with Filter while left node is Project and right node is Filter") { - val attributeA = 'a.int - val attributeB = 'b.int + val attributeA = Symbol("a").int + val attributeB = Symbol("b").int val table1 = LocalRelation.fromExternalRows(Seq(attributeA, attributeB), data = Seq(Row(1, 2))) val table2 = Project(Seq(attributeA, attributeB), @@ -152,23 +153,24 @@ class ReplaceOperatorSuite extends PlanTest { } test("replace Except with Left-anti Join") { - val table1 = LocalRelation('a.int, 'b.int) - val table2 = LocalRelation('c.int, 'd.int) + val table1 = LocalRelation(Symbol("a").int, Symbol("b").int) + val table2 = LocalRelation(Symbol("c").int, Symbol("d").int) val query = Except(table1, table2, isAll = false) val optimized = Optimize.execute(query.analyze) val correctAnswer = Aggregate(table1.output, table1.output, - Join(table1, table2, LeftAnti, Option('a <=> 'c && 'b <=> 'd), JoinHint.NONE)).analyze + Join(table1, table2, LeftAnti, Option(Symbol("a") <=> Symbol("c") && + Symbol("b") <=> Symbol("d")), JoinHint.NONE)).analyze comparePlans(optimized, correctAnswer) } test("replace Except with Filter when only right filter can be applied to the left") { - val table = LocalRelation(Seq('a.int, 'b.int)) - val left = table.where('b < 1).select('a).as("left") - val right = table.where('b < 3).select('a).as("right") + val table = LocalRelation(Seq(Symbol("a").int, Symbol("b").int)) + val left = table.where(Symbol("b") < 1).select(Symbol("a")).as("left") + val right = table.where(Symbol("b") < 3).select(Symbol("a")).as("right") val query = Except(left, right, isAll = false) val optimized = Optimize.execute(query.analyze) @@ -181,7 +183,7 @@ class ReplaceOperatorSuite extends PlanTest { } test("replace Distinct with Aggregate") { - val input = LocalRelation('a.int, 'b.int) + val input = LocalRelation(Symbol("a").int, Symbol("b").int) val query = Distinct(input) val optimized = Optimize.execute(query.analyze) @@ -192,7 +194,7 @@ class ReplaceOperatorSuite extends PlanTest { } test("replace batch Deduplicate with Aggregate") { - val input = LocalRelation('a.int, 'b.int) + val input = LocalRelation(Symbol("a").int, Symbol("b").int) val attrA = input.output(0) val attrB = input.output(1) val query = Deduplicate(Seq(attrA), input) // dropDuplicates("a") @@ -219,7 +221,7 
@@ class ReplaceOperatorSuite extends PlanTest { } test("don't replace streaming Deduplicate") { - val input = LocalRelation(Seq('a.int, 'b.int), isStreaming = true) + val input = LocalRelation(Seq(Symbol("a").int, Symbol("b").int), isStreaming = true) val attrA = input.output(0) val query = Deduplicate(Seq(attrA), input) // dropDuplicates("a") val optimized = Optimize.execute(query.analyze) @@ -228,21 +230,22 @@ class ReplaceOperatorSuite extends PlanTest { } test("SPARK-26366: ReplaceExceptWithFilter should handle properly NULL") { - val basePlan = LocalRelation(Seq('a.int, 'b.int)) - val otherPlan = basePlan.where('a.in(1, 2) || 'b.in()) + val basePlan = LocalRelation(Seq(Symbol("a").int, Symbol("b").int)) + val otherPlan = basePlan.where(Symbol("a").in(1, 2) || Symbol("b").in()) val except = Except(basePlan, otherPlan, false) val result = OptimizeIn(Optimize.execute(except.analyze)) val correctAnswer = Aggregate(basePlan.output, basePlan.output, Filter(!Coalesce(Seq( - 'a.in(1, 2) || If('b.isNotNull, Literal.FalseLiteral, Literal(null, BooleanType)), + Symbol("a").in(1, 2) || If(Symbol("b").isNotNull, + Literal.FalseLiteral, Literal(null, BooleanType)), Literal.FalseLiteral)), basePlan)).analyze comparePlans(result, correctAnswer) } test("SPARK-26366: ReplaceExceptWithFilter should not transform non-deterministic") { - val basePlan = LocalRelation(Seq('a.int, 'b.int)) - val otherPlan = basePlan.where('a > rand(1L)) + val basePlan = LocalRelation(Seq(Symbol("a").int, Symbol("b").int)) + val otherPlan = basePlan.where(Symbol("a") > rand(1L)) val except = Except(basePlan, otherPlan, false) val result = Optimize.execute(except.analyze) val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala index 5d6abf516f28..07d919fafb35 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala @@ -27,7 +27,8 @@ import org.apache.spark.sql.types.{IntegerType, StringType} class RewriteDistinctAggregatesSuite extends PlanTest { val nullInt = Literal(null, IntegerType) val nullString = Literal(null, StringType) - val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int) + val testRelation = LocalRelation(Symbol("a").string, + Symbol("b").string, Symbol("c").string, Symbol("d").string, Symbol("e").int) private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match { case Aggregate(_, _, Aggregate(_, _, _: Expand)) => @@ -36,7 +37,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest { test("single distinct group") { val input = testRelation - .groupBy('a)(countDistinct('e)) + .groupBy(Symbol("a"))(countDistinct(Symbol("e"))) .analyze val rewrite = RewriteDistinctAggregates(input) comparePlans(input, rewrite) @@ -44,9 +45,9 @@ class RewriteDistinctAggregatesSuite extends PlanTest { test("single distinct group with partial aggregates") { val input = testRelation - .groupBy('a, 'd)( - countDistinct('e, 'c).as('agg1), - max('b).as('agg2)) + .groupBy(Symbol("a"), Symbol("d"))( + countDistinct(Symbol("e"), Symbol("c")).as(Symbol("agg1")), + max(Symbol("b")).as(Symbol("agg2"))) .analyze val rewrite = RewriteDistinctAggregates(input) comparePlans(input, rewrite) @@ 
-54,24 +55,25 @@ class RewriteDistinctAggregatesSuite extends PlanTest { test("multiple distinct groups") { val input = testRelation - .groupBy('a)(countDistinct('b, 'c), countDistinct('d)) + .groupBy(Symbol("a"))(countDistinct(Symbol("b"), Symbol("c")), countDistinct(Symbol("d"))) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups with partial aggregates") { val input = testRelation - .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e)) + .groupBy(Symbol("a"))(countDistinct(Symbol("b"), + Symbol("c")), countDistinct(Symbol("d")), sum(Symbol("e"))) .analyze checkRewrite(RewriteDistinctAggregates(input)) } test("multiple distinct groups with non-partial aggregates") { val input = testRelation - .groupBy('a)( - countDistinct('b, 'c), - countDistinct('d), - CollectSet('b).toAggregateExpression()) + .groupBy(Symbol("a"))( + countDistinct(Symbol("b"), Symbol("c")), + countDistinct(Symbol("d")), + CollectSet(Symbol("b")).toAggregateExpression()) .analyze checkRewrite(RewriteDistinctAggregates(input)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala index 0608ded73937..923694b86ee6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala @@ -38,30 +38,33 @@ class RewriteSubquerySuite extends PlanTest { } test("Column pruning after rewriting predicate subquery") { - val relation = LocalRelation('a.int, 'b.int) - val relInSubquery = LocalRelation('x.int, 'y.int, 'z.int) + val relation = LocalRelation(Symbol("a").int, Symbol("b").int) + val relInSubquery = LocalRelation(Symbol("x").int, Symbol("y").int, Symbol("z").int) - val query = relation.where('a.in(ListQuery(relInSubquery.select('x)))).select('a) + val query = relation.where(Symbol("a").in( + ListQuery(relInSubquery.select(Symbol("x"))))).select(Symbol("a")) val optimized = Optimize.execute(query.analyze) val correctAnswer = relation - .select('a) - .join(relInSubquery.select('x), LeftSemi, Some('a === 'x)) + .select(Symbol("a")) + .join(relInSubquery.select(Symbol("x")), LeftSemi, Some(Symbol("a") === Symbol("x"))) .analyze comparePlans(optimized, correctAnswer) } test("NOT-IN subquery nested inside OR") { - val relation1 = LocalRelation('a.int, 'b.int) - val relation2 = LocalRelation('c.int, 'd.int) - val exists = 'exists.boolean.notNull + val relation1 = LocalRelation(Symbol("a").int, Symbol("b").int) + val relation2 = LocalRelation(Symbol("c").int, Symbol("d").int) + val exists = Symbol("exists").boolean.notNull - val query = relation1.where('b === 1 || Not('a.in(ListQuery(relation2.select('c))))).select('a) + val query = relation1.where(Symbol("b") === 1 || + Not(Symbol("a").in(ListQuery(relation2.select(Symbol("c")))))).select(Symbol("a")) val correctAnswer = relation1 - .join(relation2.select('c), ExistenceJoin(exists), Some('a === 'c || IsNull('a === 'c))) - .where('b === 1 || Not(exists)) - .select('a) + .join(relation2.select(Symbol("c")), ExistenceJoin(exists), + Some(Symbol("a") === Symbol("c") || IsNull(Symbol("a") === Symbol("c")))) + .where(Symbol("b") === 1 || Not(exists)) + .select(Symbol("a")) .analyze val optimized = Optimize.execute(query.analyze) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index 8543b62fd8bd..177558eb097a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -38,9 +38,9 @@ class SetOperationSuite extends PlanTest { PruneFilters) :: Nil } - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) - val testRelation3 = LocalRelation('g.int, 'h.int, 'i.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val testRelation2 = LocalRelation(Symbol("d").int, Symbol("e").int, Symbol("f").int) + val testRelation3 = LocalRelation(Symbol("g").int, Symbol("h").int, Symbol("i").int) val testUnion = Union(testRelation :: testRelation2 :: testRelation3 :: Nil) test("union: combine unions into one unions") { @@ -59,33 +59,33 @@ class SetOperationSuite extends PlanTest { } test("union: filter to each side") { - val unionQuery = testUnion.where('a === 1) + val unionQuery = testUnion.where(Symbol("a") === 1) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Union(testRelation.where('a === 1) :: - testRelation2.where('d === 1) :: - testRelation3.where('g === 1) :: Nil).analyze + Union(testRelation.where(Symbol("a") === 1) :: + testRelation2.where(Symbol("d") === 1) :: + testRelation3.where(Symbol("g") === 1) :: Nil).analyze comparePlans(unionOptimized, unionCorrectAnswer) } test("union: project to each side") { - val unionQuery = testUnion.select('a) + val unionQuery = testUnion.select(Symbol("a")) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Union(testRelation.select('a) :: - testRelation2.select('d) :: - testRelation3.select('g) :: Nil).analyze + Union(testRelation.select(Symbol("a")) :: + testRelation2.select(Symbol("d")) :: + testRelation3.select(Symbol("g")) :: Nil).analyze comparePlans(unionOptimized, unionCorrectAnswer) } test("Remove unnecessary distincts in multiple unions") { val query1 = OneRowRelation() - .select(Literal(1).as('a)) + .select(Literal(1).as(Symbol("a"))) val query2 = OneRowRelation() - .select(Literal(2).as('b)) + .select(Literal(2).as(Symbol("b"))) val query3 = OneRowRelation() - .select(Literal(3).as('c)) + .select(Literal(3).as(Symbol("c"))) // D - U - D - U - query1 // | | @@ -113,13 +113,13 @@ class SetOperationSuite extends PlanTest { test("Keep necessary distincts in multiple unions") { val query1 = OneRowRelation() - .select(Literal(1).as('a)) + .select(Literal(1).as(Symbol("a"))) val query2 = OneRowRelation() - .select(Literal(2).as('b)) + .select(Literal(2).as(Symbol("b"))) val query3 = OneRowRelation() - .select(Literal(3).as('c)) + .select(Literal(3).as(Symbol("c"))) val query4 = OneRowRelation() - .select(Literal(4).as('d)) + .select(Literal(4).as(Symbol("d"))) // U - D - U - query1 // | | @@ -150,10 +150,12 @@ class SetOperationSuite extends PlanTest { val input = Except(testRelation, testRelation2, isAll = true) val rewrittenPlan = RewriteExceptAll(input) - val planFragment = testRelation.select(Literal(1L).as("vcol"), 'a, 'b, 'c) - .union(testRelation2.select(Literal(-1L).as("vcol"), 'd, 'e, 'f)) - .groupBy('a, 'b, 'c)('a, 'b, 'c, sum('vcol).as("sum")) - .where(GreaterThan('sum, Literal(0L))).analyze + val planFragment = testRelation + .select(Literal(1L).as("vcol"), Symbol("a"), Symbol("b"), Symbol("c")) + 
.union(testRelation2.select(Literal(-1L).as("vcol"), Symbol("d"), Symbol("e"), Symbol("f"))) + .groupBy(Symbol("a"), Symbol("b"), Symbol("c"))( + Symbol("a"), Symbol("b"), Symbol("c"), sum(Symbol("vcol")).as("sum")) + .where(GreaterThan(Symbol("sum"), Literal(0L))).analyze val multiplierAttr = planFragment.output.last val output = planFragment.output.dropRight(1) val expectedPlan = Project(output, @@ -172,16 +174,19 @@ class SetOperationSuite extends PlanTest { val input = Intersect(testRelation, testRelation2, isAll = true) val rewrittenPlan = RewriteIntersectAll(input) val leftRelation = testRelation - .select(Literal(true).as("vcol1"), Literal(null, BooleanType).as("vcol2"), 'a, 'b, 'c) + .select(Literal(true).as("vcol1"), + Literal(null, BooleanType).as("vcol2"), Symbol("a"), Symbol("b"), Symbol("c")) val rightRelation = testRelation2 - .select(Literal(null, BooleanType).as("vcol1"), Literal(true).as("vcol2"), 'd, 'e, 'f) + .select(Literal(null, BooleanType).as("vcol1"), + Literal(true).as("vcol2"), Symbol("d"), Symbol("e"), Symbol("f")) val planFragment = leftRelation.union(rightRelation) - .groupBy('a, 'b, 'c)(count('vcol1).as("vcol1_count"), - count('vcol2).as("vcol2_count"), 'a, 'b, 'c) - .where(And(GreaterThanOrEqual('vcol1_count, Literal(1L)), - GreaterThanOrEqual('vcol2_count, Literal(1L)))) - .select('a, 'b, 'c, - If(GreaterThan('vcol1_count, 'vcol2_count), 'vcol2_count, 'vcol1_count).as("min_count")) + .groupBy(Symbol("a"), Symbol("b"), Symbol("c"))(count(Symbol("vcol1")).as("vcol1_count"), + count(Symbol("vcol2")).as("vcol2_count"), Symbol("a"), Symbol("b"), Symbol("c")) + .where(And(GreaterThanOrEqual(Symbol("vcol1_count"), Literal(1L)), + GreaterThanOrEqual(Symbol("vcol2_count"), Literal(1L)))) + .select(Symbol("a"), Symbol("b"), Symbol("c"), + If(GreaterThan(Symbol("vcol1_count"), Symbol("vcol2_count")), + Symbol("vcol2_count"), Symbol("vcol1_count")).as("min_count")) .analyze val multiplierAttr = planFragment.output.last val output = planFragment.output.dropRight(1) @@ -198,27 +203,27 @@ class SetOperationSuite extends PlanTest { } test("SPARK-23356 union: expressions with literal in project list are pushed down") { - val unionQuery = testUnion.select(('a + 1).as("aa")) + val unionQuery = testUnion.select((Symbol("a") + 1).as("aa")) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Union(testRelation.select(('a + 1).as("aa")) :: - testRelation2.select(('d + 1).as("aa")) :: - testRelation3.select(('g + 1).as("aa")) :: Nil).analyze + Union(testRelation.select((Symbol("a") + 1).as("aa")) :: + testRelation2.select((Symbol("d") + 1).as("aa")) :: + testRelation3.select((Symbol("g") + 1).as("aa")) :: Nil).analyze comparePlans(unionOptimized, unionCorrectAnswer) } test("SPARK-23356 union: expressions in project list are pushed down") { - val unionQuery = testUnion.select(('a + 'b).as("ab")) + val unionQuery = testUnion.select((Symbol("a") + Symbol("b")).as("ab")) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = - Union(testRelation.select(('a + 'b).as("ab")) :: - testRelation2.select(('d + 'e).as("ab")) :: - testRelation3.select(('g + 'h).as("ab")) :: Nil).analyze + Union(testRelation.select((Symbol("a") + Symbol("b")).as("ab")) :: + testRelation2.select((Symbol("d") + Symbol("e")).as("ab")) :: + testRelation3.select((Symbol("g") + Symbol("h")).as("ab")) :: Nil).analyze comparePlans(unionOptimized, unionCorrectAnswer) } test("SPARK-23356 union: no pushdown for non-deterministic expression") { - val unionQuery = 
testUnion.select('a, Rand(10).as("rnd")) + val unionQuery = testUnion.select(Symbol("a"), Rand(10).as("rnd")) val unionOptimized = Optimize.execute(unionQuery.analyze) val unionCorrectAnswer = unionQuery.analyze comparePlans(unionOptimized, unionCorrectAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala index c981cee55d0f..6dcd6f21bbfc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCastsSuite.scala @@ -31,15 +31,15 @@ class SimplifyCastsSuite extends PlanTest { } test("non-nullable element array to nullable element array cast") { - val input = LocalRelation('a.array(ArrayType(IntegerType, false))) - val plan = input.select('a.cast(ArrayType(IntegerType, true)).as("casted")).analyze + val input = LocalRelation(Symbol("a").array(ArrayType(IntegerType, false))) + val plan = input.select(Symbol("a").cast(ArrayType(IntegerType, true)).as("casted")).analyze val optimized = Optimize.execute(plan) - val expected = input.select('a.as("casted")).analyze + val expected = input.select(Symbol("a").as("casted")).analyze comparePlans(optimized, expected) } test("nullable element to non-nullable element array cast") { - val input = LocalRelation('a.array(ArrayType(IntegerType, true))) + val input = LocalRelation(Symbol("a").array(ArrayType(IntegerType, true))) val attr = input.output.head val plan = input.select(attr.cast(ArrayType(IntegerType, false)).as("casted")) val optimized = Optimize.execute(plan) @@ -49,16 +49,16 @@ class SimplifyCastsSuite extends PlanTest { } test("non-nullable value map to nullable value map cast") { - val input = LocalRelation('m.map(MapType(StringType, StringType, false))) - val plan = input.select('m.cast(MapType(StringType, StringType, true)) + val input = LocalRelation(Symbol("m").map(MapType(StringType, StringType, false))) + val plan = input.select(Symbol("m").cast(MapType(StringType, StringType, true)) .as("casted")).analyze val optimized = Optimize.execute(plan) - val expected = input.select('m.as("casted")).analyze + val expected = input.select(Symbol("m").as("casted")).analyze comparePlans(optimized, expected) } test("nullable value map to non-nullable value map cast") { - val input = LocalRelation('m.map(MapType(StringType, StringType, true))) + val input = LocalRelation(Symbol("m").map(MapType(StringType, StringType, true))) val attr = input.output.head val plan = input.select(attr.cast(MapType(StringType, StringType, false)) .as("casted")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 2a685bfeefcb..dbb3b07a7200 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -35,7 +35,7 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P BooleanSimplification, ConstantFolding, SimplifyConditionals) :: Nil } - private val relation = LocalRelation('a.int, 'b.int, 'c.boolean) + private val relation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").boolean) protected def assertEquivalent(e1: Expression, e2: 
Expression): Unit = { val correctAnswer = Project(Alias(e2, "out")() :: Nil, relation).analyze @@ -126,9 +126,9 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P test("simplify CaseWhen if all the outputs are semantic equivalence") { // When the conditions in `CaseWhen` are all deterministic, `CaseWhen` can be removed. assertEquivalent( - CaseWhen(('a.isNotNull, Subtract(Literal(3), Literal(2))) :: - ('b.isNull, Literal(1)) :: - (!'c, Add(Literal(6), Literal(-5))) :: + CaseWhen((Symbol("a").isNotNull, Subtract(Literal(3), Literal(2))) :: + (Symbol("b").isNull, Literal(1)) :: + (!Symbol("c"), Add(Literal(6), Literal(-5))) :: Nil, Add(Literal(2), Literal(-1))), Literal(1) @@ -167,19 +167,19 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P } test("simplify if when one clause is null and another is boolean") { - val p = IsNull('a) + val p = IsNull(Symbol("a")) val nullLiteral = Literal(null, BooleanType) assertEquivalent(If(p, nullLiteral, FalseLiteral), And(p, nullLiteral)) - assertEquivalent(If(p, nullLiteral, TrueLiteral), Or(IsNotNull('a), nullLiteral)) - assertEquivalent(If(p, FalseLiteral, nullLiteral), And(IsNotNull('a), nullLiteral)) + assertEquivalent(If(p, nullLiteral, TrueLiteral), Or(IsNotNull(Symbol("a")), nullLiteral)) + assertEquivalent(If(p, FalseLiteral, nullLiteral), And(IsNotNull(Symbol("a")), nullLiteral)) assertEquivalent(If(p, TrueLiteral, nullLiteral), Or(p, nullLiteral)) // the rule should not apply to nullable predicate Seq(TrueLiteral, FalseLiteral).foreach { b => - assertEquivalent(If(GreaterThan('a, 42), nullLiteral, b), - If(GreaterThan('a, 42), nullLiteral, b)) - assertEquivalent(If(GreaterThan('a, 42), b, nullLiteral), - If(GreaterThan('a, 42), b, nullLiteral)) + assertEquivalent(If(GreaterThan(Symbol("a"), 42), nullLiteral, b), + If(GreaterThan(Symbol("a"), 42), nullLiteral, b)) + assertEquivalent(If(GreaterThan(Symbol("a"), 42), b, nullLiteral), + If(GreaterThan(Symbol("a"), 42), b, nullLiteral)) } // check evaluation also @@ -203,10 +203,10 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P test("SPARK-33845: remove unnecessary if when the outputs are boolean type") { // verify the boolean equivalence of all transformations involved val fields = Seq( - 'cond.boolean.notNull, - 'cond_nullable.boolean, - 'a.boolean, - 'b.boolean + Symbol("cond").boolean.notNull, + Symbol("cond_nullable").boolean, + Symbol("a").boolean, + Symbol("b").boolean ) val Seq(cond, cond_nullable, a, b) = fields.zipWithIndex.map { case (f, i) => f.at(i) } @@ -238,7 +238,7 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P test("SPARK-33847: Remove the CaseWhen if elseValue is empty and other outputs are null") { assertEquivalent( - CaseWhen((GreaterThan('a, 1), Literal.create(null, IntegerType)) :: Nil, None), + CaseWhen((GreaterThan(Symbol("a"), 1), Literal.create(null, IntegerType)) :: Nil, None), Literal.create(null, IntegerType)) assertEquivalent( @@ -249,10 +249,10 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P test("SPARK-33884: simplify CaseWhen clauses with (true and false) and (false and true)") { // verify the boolean equivalence of all transformations involved val fields = Seq( - 'cond.boolean.notNull, - 'cond_nullable.boolean, - 'a.boolean, - 'b.boolean + Symbol("cond").boolean.notNull, + Symbol("cond_nullable").boolean, + Symbol("a").boolean, + Symbol("b").boolean ) val Seq(cond, cond_nullable, a, b) = 
fields.zipWithIndex.map { case (f, i) => f.at(i) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalsInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalsInPredicateSuite.scala index 04ebb4e63c67..6e1300377159 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalsInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalsInPredicateSuite.scala @@ -41,8 +41,9 @@ class SimplifyConditionalsInPredicateSuite extends PlanTest { } private val testRelation = - LocalRelation('i.int, 'b.boolean, 'a.array(IntegerType), 'm.map(IntegerType, IntegerType)) - private val anotherTestRelation = LocalRelation('d.int) + LocalRelation(Symbol("i").int, Symbol("b").boolean, + Symbol("a").array(IntegerType), Symbol("m").map(IntegerType, IntegerType)) + private val anotherTestRelation = LocalRelation(Symbol("d").int) test("IF(cond, trueVal, false) => AND(cond, trueVal)") { val originalCond = If( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala index b9bf930f0ea0..4768e7a1a0c8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyStringCaseConversionSuite.scala @@ -32,17 +32,17 @@ class SimplifyStringCaseConversionSuite extends PlanTest { SimplifyCaseConversionExpressions) :: Nil } - val testRelation = LocalRelation('a.string) + val testRelation = LocalRelation(Symbol("a").string) test("simplify UPPER(UPPER(str))") { val originalQuery = testRelation - .select(Upper(Upper('a)) as 'u) + .select(Upper(Upper(Symbol("a"))) as Symbol("u")) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select(Upper('a) as 'u) + .select(Upper(Symbol("a")) as Symbol("u")) .analyze comparePlans(optimized, correctAnswer) @@ -51,12 +51,12 @@ class SimplifyStringCaseConversionSuite extends PlanTest { test("simplify UPPER(LOWER(str))") { val originalQuery = testRelation - .select(Upper(Lower('a)) as 'u) + .select(Upper(Lower(Symbol("a"))) as Symbol("u")) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select(Upper('a) as 'u) + .select(Upper(Symbol("a")) as Symbol("u")) .analyze comparePlans(optimized, correctAnswer) @@ -65,11 +65,11 @@ class SimplifyStringCaseConversionSuite extends PlanTest { test("simplify LOWER(UPPER(str))") { val originalQuery = testRelation - .select(Lower(Upper('a)) as 'l) + .select(Lower(Upper(Symbol("a"))) as Symbol("l")) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select(Lower('a) as 'l) + .select(Lower(Symbol("a")) as Symbol("l")) .analyze comparePlans(optimized, correctAnswer) @@ -78,11 +78,11 @@ class SimplifyStringCaseConversionSuite extends PlanTest { test("simplify LOWER(LOWER(str))") { val originalQuery = testRelation - .select(Lower(Lower('a)) as 'l) + .select(Lower(Lower(Symbol("a"))) as Symbol("l")) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select(Lower('a) as 'l) + .select(Lower(Symbol("a")) as Symbol("l")) .analyze comparePlans(optimized, correctAnswer) diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala index 4acd57832d2f..7810c221a85e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala @@ -31,7 +31,8 @@ class TransposeWindowSuite extends PlanTest { Batch("FlipWindow", Once, CollapseWindow, TransposeWindow) :: Nil } - val testRelation = LocalRelation('a.string, 'b.string, 'c.int, 'd.string) + val testRelation = + LocalRelation(Symbol("a").string, Symbol("b").string, Symbol("c").int, Symbol("d").string) val a = testRelation.output(0) val b = testRelation.output(1) @@ -48,40 +49,42 @@ class TransposeWindowSuite extends PlanTest { test("transpose two adjacent windows with compatible partitions") { val query = testRelation - .window(Seq(sum(c).as('sum_a_2)), partitionSpec2, orderSpec2) - .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, orderSpec1) + .window(Seq(sum(c).as(Symbol("sum_a_2"))), partitionSpec2, orderSpec2) + .window(Seq(sum(c).as(Symbol("sum_a_1"))), partitionSpec1, orderSpec1) val analyzed = query.analyze val optimized = Optimize.execute(analyzed) val correctAnswer = testRelation - .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, orderSpec1) - .window(Seq(sum(c).as('sum_a_2)), partitionSpec2, orderSpec2) - .select('a, 'b, 'c, 'd, 'sum_a_2, 'sum_a_1) + .window(Seq(sum(c).as(Symbol("sum_a_1"))), partitionSpec1, orderSpec1) + .window(Seq(sum(c).as(Symbol("sum_a_2"))), partitionSpec2, orderSpec2) + .select(Symbol("a"), Symbol("b"), Symbol("c"), + Symbol("d"), Symbol("sum_a_2"), Symbol("sum_a_1")) comparePlans(optimized, correctAnswer.analyze) } test("transpose two adjacent windows with differently ordered compatible partitions") { val query = testRelation - .window(Seq(sum(c).as('sum_a_2)), partitionSpec4, Seq.empty) - .window(Seq(sum(c).as('sum_a_1)), partitionSpec2, Seq.empty) + .window(Seq(sum(c).as(Symbol("sum_a_2"))), partitionSpec4, Seq.empty) + .window(Seq(sum(c).as(Symbol("sum_a_1"))), partitionSpec2, Seq.empty) val analyzed = query.analyze val optimized = Optimize.execute(analyzed) val correctAnswer = testRelation - .window(Seq(sum(c).as('sum_a_1)), partitionSpec2, Seq.empty) - .window(Seq(sum(c).as('sum_a_2)), partitionSpec4, Seq.empty) - .select('a, 'b, 'c, 'd, 'sum_a_2, 'sum_a_1) + .window(Seq(sum(c).as(Symbol("sum_a_1"))), partitionSpec2, Seq.empty) + .window(Seq(sum(c).as(Symbol("sum_a_2"))), partitionSpec4, Seq.empty) + .select(Symbol("a"), Symbol("b"), Symbol("c"), + Symbol("d"), Symbol("sum_a_2"), Symbol("sum_a_1")) comparePlans(optimized, correctAnswer.analyze) } test("don't transpose two adjacent windows with incompatible partitions") { val query = testRelation - .window(Seq(sum(c).as('sum_a_2)), partitionSpec3, Seq.empty) - .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, Seq.empty) + .window(Seq(sum(c).as(Symbol("sum_a_2"))), partitionSpec3, Seq.empty) + .window(Seq(sum(c).as(Symbol("sum_a_1"))), partitionSpec1, Seq.empty) val analyzed = query.analyze val optimized = Optimize.execute(analyzed) @@ -91,8 +94,9 @@ class TransposeWindowSuite extends PlanTest { test("don't transpose two adjacent windows with intersection of partition and output set") { val query = testRelation - .window(Seq(('a + 'b).as('e), sum(c).as('sum_a_2)), partitionSpec3, Seq.empty) - .window(Seq(sum(c).as('sum_a_1)), Seq(a, 'e), Seq.empty) + 
.window(Seq((Symbol("a") + Symbol("b")).as(Symbol("e")), + sum(c).as(Symbol("sum_a_2"))), partitionSpec3, Seq.empty) + .window(Seq(sum(c).as(Symbol("sum_a_1"))), Seq(a, Symbol("e")), Seq.empty) val analyzed = query.analyze val optimized = Optimize.execute(analyzed) @@ -102,8 +106,9 @@ class TransposeWindowSuite extends PlanTest { test("don't transpose two adjacent windows with non-deterministic expressions") { val query = testRelation - .window(Seq(Rand(0).as('e), sum(c).as('sum_a_2)), partitionSpec3, Seq.empty) - .window(Seq(sum(c).as('sum_a_1)), partitionSpec1, Seq.empty) + .window(Seq(Rand(0).as(Symbol("e")), + sum(c).as(Symbol("sum_a_2"))), partitionSpec3, Seq.empty) + .window(Seq(sum(c).as(Symbol("sum_a_1"))), partitionSpec1, Seq.empty) val analyzed = query.analyze val optimized = Optimize.execute(analyzed) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala index 5fc99a3a57c0..678ba3be339c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -39,7 +39,7 @@ class TypedFilterOptimizationSuite extends PlanTest { implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() - val testRelation = LocalRelation('_1.int, '_2.int) + val testRelation = LocalRelation(Symbol("_1").int, Symbol("_2").int) test("filter after serialize with the same object type") { val f = (i: (Int, Int)) => i._1 > 0 @@ -53,7 +53,7 @@ class TypedFilterOptimizationSuite extends PlanTest { val expected = testRelation .deserialize[(Int, Int)] - .where(callFunction(f, BooleanType, 'obj)) + .where(callFunction(f, BooleanType, Symbol("obj"))) .serialize[(Int, Int)].analyze comparePlans(optimized, expected) @@ -82,7 +82,7 @@ class TypedFilterOptimizationSuite extends PlanTest { val expected = testRelation .deserialize[(Int, Int)] - .where(callFunction(f, BooleanType, 'obj)) + .where(callFunction(f, BooleanType, Symbol("obj"))) .serialize[(Int, Int)].analyze comparePlans(optimized, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala index 0afb166b80ca..3820fe439c91 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparisonSuite.scala @@ -36,10 +36,11 @@ class UnwrapCastInBinaryComparisonSuite extends PlanTest with ExpressionEvalHelp NullPropagation, UnwrapCastInBinaryComparison) :: Nil } - val testRelation: LocalRelation = LocalRelation('a.short, 'b.float, 'c.decimal(5, 2)) - val f: BoundReference = 'a.short.canBeNull.at(0) - val f2: BoundReference = 'b.float.canBeNull.at(1) - val f3: BoundReference = 'c.decimal(5, 2).canBeNull.at(2) + val testRelation: LocalRelation = + LocalRelation(Symbol("a").short, Symbol("b").float, Symbol("c").decimal(5, 2)) + val f: BoundReference = Symbol("a").short.canBeNull.at(0) + val f2: BoundReference = Symbol("b").float.canBeNull.at(1) + val f3: BoundReference = Symbol("c").decimal(5, 2).canBeNull.at(2) test("unwrap casts when literal == max") { val v = Short.MaxValue diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index dcd2fbbf0052..ba2a44253061 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -48,11 +48,12 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { SimplifyExtractValueOps) :: Nil } - private val idAtt = ('id).long.notNull - private val nullableIdAtt = ('nullable_id).long + private val idAtt = (Symbol("id")).long.notNull + private val nullableIdAtt = (Symbol("nullable_id")).long private val relation = LocalRelation(idAtt, nullableIdAtt) - private val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.double, 'e.int) + private val testRelation = LocalRelation( + Symbol("a").int, Symbol("b").int, Symbol("c").int, Symbol("d").double, Symbol("e").int) private def checkRule(originalQuery: LogicalPlan, correctAnswer: LogicalPlan) = { val optimized = Optimizer.execute(originalQuery.analyze) @@ -64,29 +65,29 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { val query = relation .select( GetStructField( - CreateNamedStruct(Seq("att", 'id )), + CreateNamedStruct(Seq("att", Symbol("id") )), 0, None) as "outerAtt") - val expected = relation.select('id as "outerAtt") + val expected = relation.select(Symbol("id") as "outerAtt") checkRule(query, expected) } test("explicit get from named_struct- expression maintains original deduced alias") { val query = relation - .select(GetStructField(CreateNamedStruct(Seq("att", 'id)), 0, None)) + .select(GetStructField(CreateNamedStruct(Seq("att", Symbol("id"))), 0, None)) val expected = relation - .select('id as "named_struct(att, id).att") + .select(Symbol("id") as "named_struct(att, id).att") checkRule(query, expected) } test("collapsed getStructField ontop of namedStruct") { val query = relation - .select(CreateNamedStruct(Seq("att", 'id)) as "struct1") - .select(GetStructField('struct1, 0, None) as "struct1Att") - val expected = relation.select('id as "struct1Att") + .select(CreateNamedStruct(Seq("att", Symbol("id"))) as "struct1") + .select(GetStructField(Symbol("struct1"), 0, None) as "struct1Att") + val expected = relation.select(Symbol("id") as "struct1Att") checkRule(query, expected) } @@ -94,17 +95,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { val query = relation .select( CreateNamedStruct(Seq( - "att1", 'id, - "att2", 'id * 'id)) as "struct1") + "att1", Symbol("id"), + "att2", Symbol("id") * Symbol("id"))) as "struct1") .select( - GetStructField('struct1, 0, None) as "struct1Att1", - GetStructField('struct1, 1, None) as "struct1Att2") + GetStructField(Symbol("struct1"), 0, None) as "struct1Att1", + GetStructField(Symbol("struct1"), 1, None) as "struct1Att2") val expected = relation. 
select( - 'id as "struct1Att1", - ('id * 'id) as "struct1Att2") + Symbol("id") as "struct1Att1", + (Symbol("id") * Symbol("id")) as "struct1Att2") checkRule(query, expected) } @@ -113,17 +114,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { val query = relation .select( CreateNamedStruct(Seq( - "att1", 'id, - "att2", 'id * 'id)) as "struct1") + "att1", Symbol("id"), + "att2", Symbol("id") * Symbol("id"))) as "struct1") .select( - GetStructField('struct1, 0, None), - GetStructField('struct1, 1, None)) + GetStructField(Symbol("struct1"), 0, None), + GetStructField(Symbol("struct1"), 1, None)) val expected = relation. select( - 'id as "struct1.att1", - ('id * 'id) as "struct1.att2") + Symbol("id") as "struct1.att1", + (Symbol("id") * Symbol("id")) as "struct1.att2") checkRule(query, expected) } @@ -132,21 +133,22 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { val rel = relation.select( CreateArray(Seq( CreateNamedStruct(Seq( - "att1", 'id, - "att2", 'id * 'id)), + "att1", Symbol("id"), + "att2", Symbol("id") * Symbol("id"))), CreateNamedStruct(Seq( - "att1", 'id + 1, - "att2", ('id + 1) * ('id + 1)) + "att1", Symbol("id") + 1, + "att2", (Symbol("id") + 1) * (Symbol("id") + 1)) )) ) as "arr" ) val query = rel .select( - GetArrayStructFields('arr, StructField("att1", LongType, false), 0, 1, false) as "a1", - GetArrayItem('arr, 1) as "a2", - GetStructField(GetArrayItem('arr, 1), 0, None) as "a3", + GetArrayStructFields( + Symbol("arr"), StructField("att1", LongType, false), 0, 1, false) as "a1", + GetArrayItem(Symbol("arr"), 1) as "a2", + GetStructField(GetArrayItem(Symbol("arr"), 1), 0, None) as "a3", GetArrayItem( - GetArrayStructFields('arr, + GetArrayStructFields(Symbol("arr"), StructField("att1", LongType, false), 0, 1, @@ -155,12 +157,12 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { val expected = relation .select( - CreateArray(Seq('id, 'id + 1L)) as "a1", + CreateArray(Seq(Symbol("id"), Symbol("id") + 1L)) as "a1", CreateNamedStruct(Seq( - "att1", ('id + 1L), - "att2", (('id + 1L) * ('id + 1L)))) as "a2", - ('id + 1L) as "a3", - ('id + 1L) as "a4") + "att1", (Symbol("id") + 1L), + "att2", ((Symbol("id") + 1L) * (Symbol("id") + 1L)))) as "a2", + (Symbol("id") + 1L) as "a3", + (Symbol("id") + 1L) as "a4") checkRule(query, expected) } @@ -179,19 +181,19 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { val rel = relation .select( CreateMap(Seq( - "r1", CreateNamedStruct(Seq("att1", 'id)), - "r2", CreateNamedStruct(Seq("att1", ('id + 1L))))) as "m") + "r1", CreateNamedStruct(Seq("att1", Symbol("id"))), + "r2", CreateNamedStruct(Seq("att1", (Symbol("id") + 1L))))) as "m") val query = rel .select( - GetMapValue('m, "r1") as "a1", - GetStructField(GetMapValue('m, "r1"), 0, None) as "a2", - GetMapValue('m, "r32") as "a3", - GetStructField(GetMapValue('m, "r32"), 0, None) as "a4") + GetMapValue(Symbol("m"), "r1") as "a1", + GetStructField(GetMapValue(Symbol("m"), "r1"), 0, None) as "a2", + GetMapValue(Symbol("m"), "r32") as "a3", + GetStructField(GetMapValue(Symbol("m"), "r32"), 0, None) as "a4") val expected = relation.select( - CreateNamedStruct(Seq("att1", 'id)) as "a1", - 'id as "a2", + CreateNamedStruct(Seq("att1", Symbol("id"))) as "a1", + Symbol("id") as "a2", Literal.create( null, StructType( @@ -206,21 +208,21 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { val query = relation.select( GetMapValue( CreateMap(Seq( - 'id, ('id + 1L), - ('id + 1L), ('id + 2L), - ('id + 
2L), ('id + 3L), - Literal(13L), 'id, - ('id + 3L), ('id + 4L), - ('id + 4L), ('id + 5L))), + Symbol("id"), (Symbol("id") + 1L), + (Symbol("id") + 1L), (Symbol("id") + 2L), + (Symbol("id") + 2L), (Symbol("id") + 3L), + Literal(13L), Symbol("id"), + (Symbol("id") + 3L), (Symbol("id") + 4L), + (Symbol("id") + 4L), (Symbol("id") + 5L))), 13L) as "a") val expected = relation .select( CaseWhen(Seq( - (EqualTo(13L, 'id), ('id + 1L)), - (EqualTo(13L, ('id + 1L)), ('id + 2L)), - (EqualTo(13L, ('id + 2L)), ('id + 3L)), - (Literal(true), 'id))) as "a") + (EqualTo(13L, Symbol("id")), (Symbol("id") + 1L)), + (EqualTo(13L, (Symbol("id") + 1L)), (Symbol("id") + 2L)), + (EqualTo(13L, (Symbol("id") + 2L)), (Symbol("id") + 3L)), + (Literal(true), Symbol("id")))) as "a") checkRule(query, expected) } @@ -229,19 +231,19 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetMapValue( CreateMap(Seq( - 'id, ('id + 1L), - ('id + 1L), ('id + 2L), - ('id + 2L), ('id + 3L), - ('id + 3L), ('id + 4L), - ('id + 4L), ('id + 5L))), - ('id + 3L)) as "a") + Symbol("id"), (Symbol("id") + 1L), + (Symbol("id") + 1L), (Symbol("id") + 2L), + (Symbol("id") + 2L), (Symbol("id") + 3L), + (Symbol("id") + 3L), (Symbol("id") + 4L), + (Symbol("id") + 4L), (Symbol("id") + 5L))), + (Symbol("id") + 3L)) as "a") val expected = relation .select( CaseWhen(Seq( - (EqualTo('id + 3L, 'id), ('id + 1L)), - (EqualTo('id + 3L, ('id + 1L)), ('id + 2L)), - (EqualTo('id + 3L, ('id + 2L)), ('id + 3L)), - (Literal(true), ('id + 4L)))) as "a") + (EqualTo(Symbol("id") + 3L, Symbol("id")), (Symbol("id") + 1L)), + (EqualTo(Symbol("id") + 3L, (Symbol("id") + 1L)), (Symbol("id") + 2L)), + (EqualTo(Symbol("id") + 3L, (Symbol("id") + 2L)), (Symbol("id") + 3L)), + (Literal(true), (Symbol("id") + 4L)))) as "a") checkRule(query, expected) } @@ -250,19 +252,19 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetMapValue( CreateMap(Seq( - 'id, ('id + 1L), - ('id + 1L), ('id + 2L), - ('id + 2L), ('id + 3L), - ('id + 3L), ('id + 4L), - ('id + 4L), ('id + 5L))), - 'id + 30L) as "a") + Symbol("id"), (Symbol("id") + 1L), + (Symbol("id") + 1L), (Symbol("id") + 2L), + (Symbol("id") + 2L), (Symbol("id") + 3L), + (Symbol("id") + 3L), (Symbol("id") + 4L), + (Symbol("id") + 4L), (Symbol("id") + 5L))), + Symbol("id") + 30L) as "a") val expected = relation.select( CaseWhen(Seq( - (EqualTo('id + 30L, 'id), ('id + 1L)), - (EqualTo('id + 30L, ('id + 1L)), ('id + 2L)), - (EqualTo('id + 30L, ('id + 2L)), ('id + 3L)), - (EqualTo('id + 30L, ('id + 3L)), ('id + 4L)), - (EqualTo('id + 30L, ('id + 4L)), ('id + 5L)))) as "a") + (EqualTo(Symbol("id") + 30L, Symbol("id")), (Symbol("id") + 1L)), + (EqualTo(Symbol("id") + 30L, (Symbol("id") + 1L)), (Symbol("id") + 2L)), + (EqualTo(Symbol("id") + 30L, (Symbol("id") + 2L)), (Symbol("id") + 3L)), + (EqualTo(Symbol("id") + 30L, (Symbol("id") + 3L)), (Symbol("id") + 4L)), + (EqualTo(Symbol("id") + 30L, (Symbol("id") + 4L)), (Symbol("id") + 5L)))) as "a") checkRule(rel, expected) } @@ -271,22 +273,22 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetMapValue( CreateMap(Seq( - 'id, ('id + 1L), - ('id + 1L), ('id + 2L), - ('id + 2L), ('id + 3L), - Literal(14L), 'id, - ('id + 3L), ('id + 4L), - ('id + 4L), ('id + 5L))), + Symbol("id"), (Symbol("id") + 1L), + (Symbol("id") + 1L), (Symbol("id") + 2L), + (Symbol("id") + 2L), (Symbol("id") + 3L), + Literal(14L), Symbol("id"), + (Symbol("id") + 3L), (Symbol("id") + 4L), + (Symbol("id") + 4L), (Symbol("id") + 
5L))), 13L) as "a") val expected = relation .select( CaseKeyWhen(13L, - Seq('id, ('id + 1L), - ('id + 1L), ('id + 2L), - ('id + 2L), ('id + 3L), - ('id + 3L), ('id + 4L), - ('id + 4L), ('id + 5L))) as "a") + Seq(Symbol("id"), (Symbol("id") + 1L), + (Symbol("id") + 1L), (Symbol("id") + 2L), + (Symbol("id") + 2L), (Symbol("id") + 3L), + (Symbol("id") + 3L), (Symbol("id") + 4L), + (Symbol("id") + 4L), (Symbol("id") + 5L))) as "a") checkRule(rel, expected) } @@ -296,100 +298,102 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { .select( GetMapValue( CreateMap(Seq( - 'id, ('id + 1L), - ('id + 1L), ('id + 2L), - ('id + 2L), Literal.create(null, LongType), - Literal(2L), 'id, - ('id + 3L), ('id + 4L), - ('id + 4L), ('id + 5L))), + Symbol("id"), (Symbol("id") + 1L), + (Symbol("id") + 1L), (Symbol("id") + 2L), + (Symbol("id") + 2L), Literal.create(null, LongType), + Literal(2L), Symbol("id"), + (Symbol("id") + 3L), (Symbol("id") + 4L), + (Symbol("id") + 4L), (Symbol("id") + 5L))), 2L ) as "a") val expected = relation .select( CaseWhen(Seq( - (EqualTo(2L, 'id), ('id + 1L)), + (EqualTo(2L, Symbol("id")), (Symbol("id") + 1L)), // these two are possible matches, we can't tell until runtime - (EqualTo(2L, ('id + 1L)), ('id + 2L)), - (EqualTo(2L, 'id + 2L), Literal.create(null, LongType)), + (EqualTo(2L, (Symbol("id") + 1L)), (Symbol("id") + 2L)), + (EqualTo(2L, Symbol("id") + 2L), Literal.create(null, LongType)), // this is a definite match (two constants), // but it cannot override a potential match with ('id + 2L), // which is exactly what [[Coalesce]] would do in this case. - (Literal.TrueLiteral, 'id))) as "a") + (Literal.TrueLiteral, Symbol("id")))) as "a") checkRule(rel, expected) } test("SPARK-23500: Simplify array ops that are not at the top node") { - val query = LocalRelation('id.long) + val query = LocalRelation(Symbol("id").long) .select( CreateArray(Seq( CreateNamedStruct(Seq( - "att1", 'id, - "att2", 'id * 'id)), + "att1", Symbol("id"), + "att2", Symbol("id") * Symbol("id"))), CreateNamedStruct(Seq( - "att1", 'id + 1, - "att2", ('id + 1) * ('id + 1)) + "att1", Symbol("id") + 1, + "att2", (Symbol("id") + 1) * (Symbol("id") + 1)) )) ) as "arr") .select( - GetStructField(GetArrayItem('arr, 1), 0, None) as "a1", + GetStructField(GetArrayItem(Symbol("arr"), 1), 0, None) as "a1", GetArrayItem( - GetArrayStructFields('arr, + GetArrayStructFields(Symbol("arr"), StructField("att1", LongType, nullable = false), ordinal = 0, numFields = 1, containsNull = false), ordinal = 1) as "a2") - .orderBy('id.asc) + .orderBy(Symbol("id").asc) - val expected = LocalRelation('id.long) + val expected = LocalRelation(Symbol("id").long) .select( - ('id + 1L) as "a1", - ('id + 1L) as "a2") - .orderBy('id.asc) + (Symbol("id") + 1L) as "a1", + (Symbol("id") + 1L) as "a2") + .orderBy(Symbol("id").asc) checkRule(query, expected) } test("SPARK-23500: Simplify map ops that are not top nodes") { val query = - LocalRelation('id.long) + LocalRelation(Symbol("id").long) .select( CreateMap(Seq( - "r1", 'id, - "r2", 'id + 1L)) as "m") + "r1", Symbol("id"), + "r2", Symbol("id") + 1L)) as "m") .select( - GetMapValue('m, "r1") as "a1", - GetMapValue('m, "r32") as "a2") - .orderBy('id.asc) - .select('a1, 'a2) + GetMapValue(Symbol("m"), "r1") as "a1", + GetMapValue(Symbol("m"), "r32") as "a2") + .orderBy(Symbol("id").asc) + .select(Symbol("a1"), Symbol("a2")) val expected = - LocalRelation('id.long).select( - 'id as "a1", + LocalRelation(Symbol("id").long).select( + Symbol("id") as "a1", Literal.create(null, 
LongType) as "a2") - .orderBy('id.asc) + .orderBy(Symbol("id").asc) checkRule(query, expected) } test("SPARK-23500: Simplify complex ops that aren't at the plan root") { val structRel = relation - .select(GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None) as "foo") + .select(GetStructField( + CreateNamedStruct(Seq("att1", Symbol("nullable_id"))), 0, None) as "foo") .groupBy($"foo")("1") val structExpected = relation - .select('nullable_id as "foo") + .select(Symbol("nullable_id") as "foo") .groupBy($"foo")("1") checkRule(structRel, structExpected) val arrayRel = relation - .select(GetArrayItem(CreateArray(Seq('nullable_id, 'nullable_id + 1L)), 0) as "a1") + .select(GetArrayItem( + CreateArray(Seq(Symbol("nullable_id"), Symbol("nullable_id") + 1L)), 0) as "a1") .groupBy($"a1")("1") - val arrayExpected = relation.select('nullable_id as "a1").groupBy($"a1")("1") + val arrayExpected = relation.select(Symbol("nullable_id") as "a1").groupBy($"a1")("1") checkRule(arrayRel, arrayExpected) val mapRel = relation - .select(GetMapValue(CreateMap(Seq("id", 'nullable_id)), "id") as "m1") + .select(GetMapValue(CreateMap(Seq("id", Symbol("nullable_id"))), "id") as "m1") .groupBy($"m1")("1") val mapExpected = relation - .select('nullable_id as "m1") + .select(Symbol("nullable_id") as "m1") .groupBy($"m1")("1") checkRule(mapRel, mapExpected) } @@ -398,19 +402,20 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // Make sure that aggregation exprs are correctly ignored. Maps can't be used in // grouping exprs so aren't tested here. val structAggRel = relation.groupBy( - CreateNamedStruct(Seq("att1", 'nullable_id)))( - GetStructField(CreateNamedStruct(Seq("att1", 'nullable_id)), 0, None)) + CreateNamedStruct(Seq("att1", Symbol("nullable_id"))))( + GetStructField(CreateNamedStruct(Seq("att1", Symbol("nullable_id"))), 0, None)) checkRule(structAggRel, structAggRel) val arrayAggRel = relation.groupBy( - CreateArray(Seq('nullable_id)))(GetArrayItem(CreateArray(Seq('nullable_id)), 0)) + CreateArray( + Seq(Symbol("nullable_id"))))(GetArrayItem(CreateArray(Seq(Symbol("nullable_id"))), 0)) checkRule(arrayAggRel, arrayAggRel) // This could be done if we had a more complex rule that checks that // the CreateMap does not come from key. 
val originalQuery = relation - .groupBy('id)( - GetMapValue(CreateMap(Seq('id, 'id + 1L)), 0L) as "a" + .groupBy(Symbol("id"))( + GetMapValue(CreateMap(Seq(Symbol("id"), Symbol("id") + 1L)), 0L) as "a" ) checkRule(originalQuery, originalQuery) } @@ -419,13 +424,13 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { val originalQuery = testRelation .select( - namedStruct("col1", 'b, "col2", 'c).as("s1"), 'a, 'b) - .select('s1 getField "col2" as 's1Col2, - namedStruct("col1", 'a, "col2", 'b).as("s2")) - .select('s1Col2, 's2 getField "col2" as 's2Col2) + namedStruct("col1", Symbol("b"), "col2", Symbol("c")).as("s1"), Symbol("a"), Symbol("b")) + .select(Symbol("s1") getField "col2" as Symbol("s1Col2"), + namedStruct("col1", Symbol("a"), "col2", Symbol("b")).as("s2")) + .select(Symbol("s1Col2"), Symbol("s2") getField "col2" as Symbol("s2Col2")) val correctAnswer = testRelation - .select('c as 's1Col2, 'b as 's2Col2) + .select(Symbol("c") as Symbol("s1Col2"), Symbol("b") as Symbol("s2Col2")) checkRule(originalQuery, correctAnswer) } @@ -433,11 +438,11 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { val originalQuery = testRelation .select( - namedStruct("col1", 'b, "col2", 'c) getField "col2" as 'sCol2, - namedStruct("col1", 'a, "col2", 'c) getField "col1" as 'sCol1) + namedStruct("col1", Symbol("b"), "col2", Symbol("c")) getField "col2" as Symbol("sCol2"), + namedStruct("col1", Symbol("a"), "col2", Symbol("c")) getField "col1" as Symbol("sCol1")) val correctAnswer = testRelation - .select('c as 'sCol2, 'a as 'sCol1) + .select(Symbol("c") as Symbol("sCol2"), Symbol("a") as Symbol("sCol1")) checkRule(originalQuery, correctAnswer) } @@ -454,16 +459,18 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null) } - private val structAttr = 'struct1.struct('a.int, 'b.int).withNullability(false) + private val structAttr = + Symbol("struct1").struct(Symbol("a").int, Symbol("b").int).withNullability(false) private val testStructRelation = LocalRelation(structAttr) - private val nullableStructAttr = 'struct1.struct('a.int, 'b.int) + private val nullableStructAttr = Symbol("struct1").struct(Symbol("a").int, Symbol("b").int) private val testNullableStructRelation = LocalRelation(nullableStructAttr) test("simplify GetStructField on basic UpdateFields") { def check(fieldOps: Seq[StructFieldsOperation], ordinal: Int, expected: Expression): Unit = { def query(relation: LocalRelation): LogicalPlan = - relation.select(GetStructField(UpdateFields('struct1, fieldOps), ordinal).as("res")) + relation.select( + GetStructField(UpdateFields(Symbol("struct1"), fieldOps), ordinal).as("res")) checkRule( query(testStructRelation), @@ -473,30 +480,30 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { query(testNullableStructRelation), testNullableStructRelation.select((expected match { case expr: GetStructField => expr - case expr => If(IsNull('struct1), Literal(null, expr.dataType), expr) + case expr => If(IsNull(Symbol("struct1")), Literal(null, expr.dataType), expr) }).as("res"))) } // scalastyle:off line.size.limit // add attribute, extract an attribute from the original struct - check(WithField("c", Literal(3)) :: Nil, 0, GetStructField('struct1, 0)) - check(WithField("c", Literal(3)) :: Nil, 1, GetStructField('struct1, 1)) + check(WithField("c", Literal(3)) :: Nil, 0, GetStructField(Symbol("struct1"), 0)) + check(WithField("c", Literal(3)) :: Nil, 1, 
GetStructField(Symbol("struct1"), 1)) // add attribute, extract added attribute check(WithField("c", Literal(3)) :: Nil, 2, Literal(3)) // replace attribute, extract an attribute from the original struct - check(WithField("a", Literal(1)) :: Nil, 1, GetStructField('struct1, 1)) - check(WithField("b", Literal(2)) :: Nil, 0, GetStructField('struct1, 0)) + check(WithField("a", Literal(1)) :: Nil, 1, GetStructField(Symbol("struct1"), 1)) + check(WithField("b", Literal(2)) :: Nil, 0, GetStructField(Symbol("struct1"), 0)) // replace attribute, extract replaced attribute check(WithField("a", Literal(1)) :: Nil, 0, Literal(1)) check(WithField("b", Literal(2)) :: Nil, 1, Literal(2)) // add multiple attributes, extract an attribute from the original struct - check(WithField("c", Literal(3)) :: WithField("c", Literal(4)) :: Nil, 0, GetStructField('struct1, 0)) - check(WithField("c", Literal(3)) :: WithField("d", Literal(4)) :: Nil, 0, GetStructField('struct1, 0)) - check(WithField("c", Literal(3)) :: WithField("c", Literal(4)) :: Nil, 1, GetStructField('struct1, 1)) - check(WithField("c", Literal(3)) :: WithField("d", Literal(4)) :: Nil, 1, GetStructField('struct1, 1)) + check(WithField("c", Literal(3)) :: WithField("c", Literal(4)) :: Nil, 0, GetStructField(Symbol("struct1"), 0)) + check(WithField("c", Literal(3)) :: WithField("d", Literal(4)) :: Nil, 0, GetStructField(Symbol("struct1"), 0)) + check(WithField("c", Literal(3)) :: WithField("c", Literal(4)) :: Nil, 1, GetStructField(Symbol("struct1"), 1)) + check(WithField("c", Literal(3)) :: WithField("d", Literal(4)) :: Nil, 1, GetStructField(Symbol("struct1"), 1)) // add multiple attributes, extract newly added attribute check(WithField("c", Literal(3)) :: WithField("c", Literal(4)) :: Nil, 2, Literal(4)) check(WithField("c", Literal(4)) :: WithField("c", Literal(3)) :: Nil, 2, Literal(3)) @@ -506,45 +513,45 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { check(WithField("d", Literal(4)) :: WithField("c", Literal(3)) :: Nil, 3, Literal(3)) // drop attribute, extract an attribute from the original struct - check(DropField("b") :: Nil, 0, GetStructField('struct1, 0)) - check(DropField("a") :: Nil, 0, GetStructField('struct1, 1)) + check(DropField("b") :: Nil, 0, GetStructField(Symbol("struct1"), 0)) + check(DropField("a") :: Nil, 0, GetStructField(Symbol("struct1"), 1)) // drop attribute, add attribute, extract an attribute from the original struct - check(DropField("b") :: WithField("c", Literal(3)) :: Nil, 0, GetStructField('struct1, 0)) - check(DropField("a") :: WithField("c", Literal(3)) :: Nil, 0, GetStructField('struct1, 1)) + check(DropField("b") :: WithField("c", Literal(3)) :: Nil, 0, GetStructField(Symbol("struct1"), 0)) + check(DropField("a") :: WithField("c", Literal(3)) :: Nil, 0, GetStructField(Symbol("struct1"), 1)) // drop attribute, add attribute, extract added attribute check(DropField("b") :: WithField("c", Literal(3)) :: Nil, 1, Literal(3)) check(DropField("a") :: WithField("c", Literal(3)) :: Nil, 1, Literal(3)) // add attribute, drop attribute, extract an attribute from the original struct - check(WithField("c", Literal(3)) :: DropField("a") :: Nil, 0, GetStructField('struct1, 1)) - check(WithField("c", Literal(3)) :: DropField("b") :: Nil, 0, GetStructField('struct1, 0)) + check(WithField("c", Literal(3)) :: DropField("a") :: Nil, 0, GetStructField(Symbol("struct1"), 1)) + check(WithField("c", Literal(3)) :: DropField("b") :: Nil, 0, GetStructField(Symbol("struct1"), 0)) // add attribute, drop attribute, 
extract added attribute check(WithField("c", Literal(3)) :: DropField("a") :: Nil, 1, Literal(3)) check(WithField("c", Literal(3)) :: DropField("b") :: Nil, 1, Literal(3)) // replace attribute, drop same attribute, extract an attribute from the original struct - check(WithField("b", Literal(3)) :: DropField("b") :: Nil, 0, GetStructField('struct1, 0)) - check(WithField("a", Literal(3)) :: DropField("a") :: Nil, 0, GetStructField('struct1, 1)) + check(WithField("b", Literal(3)) :: DropField("b") :: Nil, 0, GetStructField(Symbol("struct1"), 0)) + check(WithField("a", Literal(3)) :: DropField("a") :: Nil, 0, GetStructField(Symbol("struct1"), 1)) // add attribute, drop same attribute, extract an attribute from the original struct - check(WithField("c", Literal(3)) :: DropField("c") :: Nil, 0, GetStructField('struct1, 0)) - check(WithField("c", Literal(3)) :: DropField("c") :: Nil, 1, GetStructField('struct1, 1)) + check(WithField("c", Literal(3)) :: DropField("c") :: Nil, 0, GetStructField(Symbol("struct1"), 0)) + check(WithField("c", Literal(3)) :: DropField("c") :: Nil, 1, GetStructField(Symbol("struct1"), 1)) // replace attribute, drop another attribute, extract added attribute check(WithField("b", Literal(3)) :: DropField("a") :: Nil, 0, Literal(3)) check(WithField("a", Literal(3)) :: DropField("b") :: Nil, 0, Literal(3)) // drop attribute, add same attribute, extract attribute from the original struct - check(DropField("b") :: WithField("b", Literal(3)) :: Nil, 0, GetStructField('struct1, 0)) - check(DropField("a") :: WithField("a", Literal(3)) :: Nil, 0, GetStructField('struct1, 1)) + check(DropField("b") :: WithField("b", Literal(3)) :: Nil, 0, GetStructField(Symbol("struct1"), 0)) + check(DropField("a") :: WithField("a", Literal(3)) :: Nil, 0, GetStructField(Symbol("struct1"), 1)) // drop attribute, add same attribute, extract added attribute check(DropField("b") :: WithField("b", Literal(3)) :: Nil, 1, Literal(3)) check(DropField("a") :: WithField("a", Literal(3)) :: Nil, 1, Literal(3)) // drop non-existent attribute, add same attribute, extract attribute from the original struct - check(DropField("c") :: WithField("c", Literal(3)) :: Nil, 0, GetStructField('struct1, 0)) - check(DropField("c") :: WithField("c", Literal(3)) :: Nil, 1, GetStructField('struct1, 1)) + check(DropField("c") :: WithField("c", Literal(3)) :: Nil, 0, GetStructField(Symbol("struct1"), 0)) + check(DropField("c") :: WithField("c", Literal(3)) :: Nil, 1, GetStructField(Symbol("struct1"), 1)) // drop non-existent attribute, add same attribute, extract added attribute check(DropField("c") :: WithField("c", Literal(3)) :: Nil, 2, Literal(3)) @@ -552,7 +559,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { } test("simplify GetStructField that is extracting a field nested inside a struct") { - val struct2 = 'struct2.struct('b.int) + val struct2 = Symbol("struct2").struct(Symbol("b").int) val testStructRelation = LocalRelation(structAttr, struct2) val testNullableStructRelation = LocalRelation(nullableStructAttr, struct2) @@ -561,15 +568,16 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { def addFieldFromSameStructAndThenExtractIt(relation: LocalRelation): LogicalPlan = relation.select(GetStructField( - UpdateFields('struct1, WithField("b", GetStructField('struct1, 0)) :: Nil), 1).as("res")) + UpdateFields(Symbol("struct1"), WithField("b", + GetStructField(Symbol("struct1"), 0)) :: Nil), 1).as("res")) checkRule( addFieldFromSameStructAndThenExtractIt(testStructRelation), - 
testStructRelation.select(GetStructField('struct1, 0).as("res"))) + testStructRelation.select(GetStructField(Symbol("struct1"), 0).as("res"))) checkRule( addFieldFromSameStructAndThenExtractIt(testNullableStructRelation), - testNullableStructRelation.select(GetStructField('struct1, 0).as("res"))) + testNullableStructRelation.select(GetStructField(Symbol("struct1"), 0).as("res"))) // if the field being extracted is from a different struct than the one UpdateFields is // modifying, we must return GetStructField wrapped in If(IsNull(struct), null, GetStructField) @@ -577,16 +585,18 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { def addFieldFromAnotherStructAndThenExtractIt(relation: LocalRelation): LogicalPlan = relation.select(GetStructField( - UpdateFields('struct1, WithField("b", GetStructField('struct2, 0)) :: Nil), 1).as("res")) + UpdateFields(Symbol("struct1"), WithField("b", + GetStructField(Symbol("struct2"), 0)) :: Nil), 1).as("res")) checkRule( addFieldFromAnotherStructAndThenExtractIt(testStructRelation), - testStructRelation.select(GetStructField('struct2, 0).as("res"))) + testStructRelation.select(GetStructField(Symbol("struct2"), 0).as("res"))) checkRule( addFieldFromAnotherStructAndThenExtractIt(testNullableStructRelation), testNullableStructRelation.select( - If(IsNull('struct1), Literal(null, IntegerType), GetStructField('struct2, 0)).as("res"))) + If(IsNull(Symbol("struct1")), Literal(null, IntegerType), + GetStructField(Symbol("struct2"), 0)).as("res"))) } test("simplify GetStructField on nested UpdateFields") { @@ -596,7 +606,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { UpdateFields( UpdateFields( UpdateFields( - 'struct1, + Symbol("struct1"), WithField("c", Literal(1)) :: Nil), WithField("d", Literal(2)) :: Nil), WithField("e", Literal(3)) :: Nil), @@ -614,79 +624,79 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { checkRule( query(testNullableStructRelation, 5), testNullableStructRelation.select( - If(IsNull('struct1), Literal(null, IntegerType), Literal(4)) as "res")) + If(IsNull(Symbol("struct1")), Literal(null, IntegerType), Literal(4)) as "res")) // extract field from original struct checkRule( query(testStructRelation, 0), - testStructRelation.select(GetStructField('struct1, 0) as "res")) + testStructRelation.select(GetStructField(Symbol("struct1"), 0) as "res")) checkRule( query(testNullableStructRelation, 0), - testNullableStructRelation.select(GetStructField('struct1, 0) as "res")) + testNullableStructRelation.select(GetStructField(Symbol("struct1"), 0) as "res")) } test("simplify multiple GetStructField on the same UpdateFields") { def query(relation: LocalRelation): LogicalPlan = relation - .select(UpdateFields('struct1, WithField("b", Literal(2)) :: Nil) as "struct2") + .select(UpdateFields(Symbol("struct1"), WithField("b", Literal(2)) :: Nil) as "struct2") .select( - GetStructField('struct2, 0, Some("a")) as "struct1A", - GetStructField('struct2, 1, Some("b")) as "struct1B") + GetStructField(Symbol("struct2"), 0, Some("a")) as "struct1A", + GetStructField(Symbol("struct2"), 1, Some("b")) as "struct1B") checkRule( query(testStructRelation), testStructRelation.select( - GetStructField('struct1, 0) as "struct1A", + GetStructField(Symbol("struct1"), 0) as "struct1A", Literal(2) as "struct1B")) checkRule( query(testNullableStructRelation), testNullableStructRelation.select( - GetStructField('struct1, 0) as "struct1A", - If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as 
"struct1B")) + GetStructField(Symbol("struct1"), 0) as "struct1A", + If(IsNull(Symbol("struct1")), Literal(null, IntegerType), Literal(2)) as "struct1B")) } test("simplify multiple GetStructField on different UpdateFields") { def query(relation: LocalRelation): LogicalPlan = relation .select( - UpdateFields('struct1, WithField("b", Literal(2)) :: Nil) as "struct2", - UpdateFields('struct1, WithField("b", Literal(3)) :: Nil) as "struct3") + UpdateFields(Symbol("struct1"), WithField("b", Literal(2)) :: Nil) as "struct2", + UpdateFields(Symbol("struct1"), WithField("b", Literal(3)) :: Nil) as "struct3") .select( - GetStructField('struct2, 0, Some("a")) as "struct2A", - GetStructField('struct2, 1, Some("b")) as "struct2B", - GetStructField('struct3, 0, Some("a")) as "struct3A", - GetStructField('struct3, 1, Some("b")) as "struct3B") + GetStructField(Symbol("struct2"), 0, Some("a")) as "struct2A", + GetStructField(Symbol("struct2"), 1, Some("b")) as "struct2B", + GetStructField(Symbol("struct3"), 0, Some("a")) as "struct3A", + GetStructField(Symbol("struct3"), 1, Some("b")) as "struct3B") checkRule( query(testStructRelation), testStructRelation .select( - GetStructField('struct1, 0) as "struct2A", + GetStructField(Symbol("struct1"), 0) as "struct2A", Literal(2) as "struct2B", - GetStructField('struct1, 0) as "struct3A", + GetStructField(Symbol("struct1"), 0) as "struct3A", Literal(3) as "struct3B")) checkRule( query(testNullableStructRelation), testNullableStructRelation .select( - GetStructField('struct1, 0) as "struct2A", - If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "struct2B", - GetStructField('struct1, 0) as "struct3A", - If(IsNull('struct1), Literal(null, IntegerType), Literal(3)) as "struct3B")) + GetStructField(Symbol("struct1"), 0) as "struct2A", + If(IsNull(Symbol("struct1")), Literal(null, IntegerType), Literal(2)) as "struct2B", + GetStructField(Symbol("struct1"), 0) as "struct3A", + If(IsNull(Symbol("struct1")), Literal(null, IntegerType), Literal(3)) as "struct3B")) } test("simplify add multiple nested fields to non-nullable struct") { // this scenario is possible if users add multiple nested columns to a non-nullable struct // using the Column.withField API in a non-performant way val structLevel2 = LocalRelation( - 'a1.struct( - 'a2.struct('a3.int.notNull)).notNull) + Symbol("a1").struct( + Symbol("a2").struct(Symbol("a3").int.notNull)).notNull) val query = { - val addB3toA1A2 = UpdateFields('a1, Seq(WithField("a2", - UpdateFields(GetStructField('a1, 0), Seq(WithField("b3", Literal(2))))))) + val addB3toA1A2 = UpdateFields(Symbol("a1"), Seq(WithField("a2", + UpdateFields(GetStructField(Symbol("a1"), 0), Seq(WithField("b3", Literal(2))))))) structLevel2.select( UpdateFields( @@ -696,9 +706,9 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { } val expected = structLevel2.select( - UpdateFields('a1, Seq( + UpdateFields(Symbol("a1"), Seq( // scalastyle:off line.size.limit - WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3", 2) :: WithField("c3", 3) :: Nil)) + WithField("a2", UpdateFields(GetStructField(Symbol("a1"), 0), WithField("b3", 2) :: WithField("c3", 3) :: Nil)) // scalastyle:on line.size.limit )).as("a1")) @@ -709,12 +719,12 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // this scenario is possible if users add multiple nested columns to a nullable struct // using the Column.withField API in a non-performant way val structLevel2 = LocalRelation( - 'a1.struct( - 'a2.struct('a3.int.notNull))) 
+ Symbol("a1").struct( + Symbol("a2").struct(Symbol("a3").int.notNull))) val query = { - val addB3toA1A2 = UpdateFields('a1, Seq(WithField("a2", - UpdateFields(GetStructField('a1, 0), Seq(WithField("b3", Literal(2))))))) + val addB3toA1A2 = UpdateFields(Symbol("a1"), Seq(WithField("a2", + UpdateFields(GetStructField(Symbol("a1"), 0), Seq(WithField("b3", Literal(2))))))) structLevel2.select( UpdateFields( @@ -724,15 +734,16 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { } val expected = { - val repeatedExpr = UpdateFields(GetStructField('a1, 0), WithField("b3", Literal(2)) :: Nil) + val repeatedExpr = + UpdateFields(GetStructField(Symbol("a1"), 0), WithField("b3", Literal(2)) :: Nil) val repeatedExprDataType = StructType(Seq( StructField("a3", IntegerType, nullable = false), StructField("b3", IntegerType, nullable = false))) structLevel2.select( - UpdateFields('a1, Seq( + UpdateFields(Symbol("a1"), Seq( WithField("a2", UpdateFields( - If(IsNull('a1), Literal(null, repeatedExprDataType), repeatedExpr), + If(IsNull(Symbol("a1")), Literal(null, repeatedExprDataType), repeatedExpr), WithField("c3", Literal(3)) :: Nil)) )).as("a1")) } @@ -744,13 +755,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // this scenario is possible if users drop multiple nested columns in a non-nullable struct // using the Column.dropFields API in a non-performant way val structLevel2 = LocalRelation( - 'a1.struct( - 'a2.struct('a3.int.notNull, 'b3.int.notNull, 'c3.int.notNull).notNull + Symbol("a1").struct( + Symbol("a2").struct(Symbol("a3").int.notNull, + Symbol("b3").int.notNull, Symbol("c3").int.notNull).notNull ).notNull) val query = { - val dropA1A2B = UpdateFields('a1, Seq(WithField("a2", UpdateFields( - GetStructField('a1, 0), Seq(DropField("b3")))))) + val dropA1A2B = UpdateFields(Symbol("a1"), Seq(WithField("a2", UpdateFields( + GetStructField(Symbol("a1"), 0), Seq(DropField("b3")))))) structLevel2.select( UpdateFields( @@ -760,8 +772,9 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { } val expected = structLevel2.select( - UpdateFields('a1, Seq( - WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(DropField("b3"), DropField("c3")))) + UpdateFields(Symbol("a1"), Seq( + WithField("a2", UpdateFields(GetStructField(Symbol("a1"), 0), + Seq(DropField("b3"), DropField("c3")))) )).as("a1")) checkRule(query, expected) @@ -771,13 +784,14 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { // this scenario is possible if users drop multiple nested columns in a nullable struct // using the Column.dropFields API in a non-performant way val structLevel2 = LocalRelation( - 'a1.struct( - 'a2.struct('a3.int.notNull, 'b3.int.notNull, 'c3.int.notNull) + Symbol("a1").struct( + Symbol("a2").struct(Symbol("a3").int.notNull, + Symbol("b3").int.notNull, Symbol("c3").int.notNull) )) val query = { - val dropA1A2B = UpdateFields('a1, Seq(WithField("a2", UpdateFields( - GetStructField('a1, 0), Seq(DropField("b3")))))) + val dropA1A2B = UpdateFields(Symbol("a1"), Seq(WithField("a2", UpdateFields( + GetStructField(Symbol("a1"), 0), Seq(DropField("b3")))))) structLevel2.select( UpdateFields( @@ -787,15 +801,15 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { } val expected = { - val repeatedExpr = UpdateFields(GetStructField('a1, 0), DropField("b3") :: Nil) + val repeatedExpr = UpdateFields(GetStructField(Symbol("a1"), 0), DropField("b3") :: Nil) val repeatedExprDataType = StructType(Seq( StructField("a3", 
IntegerType, nullable = false), StructField("c3", IntegerType, nullable = false))) structLevel2.select( - UpdateFields('a1, Seq( + UpdateFields(Symbol("a1"), Seq( WithField("a2", UpdateFields( - If(IsNull('a1), Literal(null, repeatedExprDataType), repeatedExpr), + If(IsNull(Symbol("a1")), Literal(null, repeatedExprDataType), repeatedExpr), DropField("c3") :: Nil)) )).as("a1")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/JoinReorderSuite.scala index 2e1cf4a137e2..56ad9b6039c1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/JoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/JoinReorderSuite.scala @@ -230,15 +230,15 @@ class JoinReorderSuite extends JoinReorderPlanTestBase with StatsEstimationTestB test("SPARK-26352: join reordering should not change the order of attributes") { // This test case does not rely on CBO. // It's similar to the test case above, but catches a reordering bug that the one above doesn't - val tab1 = LocalRelation('x.int, 'y.int) - val tab2 = LocalRelation('i.int, 'j.int) - val tab3 = LocalRelation('a.int, 'b.int) + val tab1 = LocalRelation(Symbol("x").int, Symbol("y").int) + val tab2 = LocalRelation(Symbol("i").int, Symbol("j").int) + val tab3 = LocalRelation(Symbol("a").int, Symbol("b").int) val original = tab1.join(tab2, Cross) - .join(tab3, Inner, Some('a === 'x && 'b === 'i)) + .join(tab3, Inner, Some(Symbol("a") === Symbol("x") && Symbol("b") === Symbol("i"))) val expected = - tab1.join(tab3, Inner, Some('a === 'x)) - .join(tab2, Cross, Some('b === 'i)) + tab1.join(tab3, Inner, Some(Symbol("a") === Symbol("x"))) + .join(tab2, Cross, Some(Symbol("b") === Symbol("i"))) .select(outputsOf(tab1, tab2, tab3): _*) assertEqualJoinPlans(Optimize, original, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/StarJoinReorderSuite.scala index ebc12b1d82cf..442ed8cd2e36 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/StarJoinReorderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/joinReorder/StarJoinReorderSuite.scala @@ -141,7 +141,8 @@ class StarJoinReorderSuite extends JoinReorderPlanTestBase with StatsEstimationT size = Some(17), attributeStats = AttributeMap(Seq("s3_pk1", "s3_c2", "s3_c3", "s3_c4").map(nameToColInfo))) - private val d3_ns = LocalRelation('d3_fk1.int, 'd3_c2.int, 'd3_pk1.int, 'd3_c4.int) + private val d3_ns = LocalRelation(Symbol("d3_fk1").int, + Symbol("d3_c2").int, Symbol("d3_pk1").int, Symbol("d3_c4").int) private val f11 = StatsTestPlan( outputList = Seq("f11_fk1", "f11_fk2", "f11_fk3", "f11_c4").map(nameToAttr), @@ -150,7 +151,7 @@ class StarJoinReorderSuite extends JoinReorderPlanTestBase with StatsEstimationT attributeStats = AttributeMap(Seq("f11_fk1", "f11_fk2", "f11_fk3", "f11_c4") .map(nameToColInfo))) - private val subq = d3.select(sum('d3_fk1).as('col)) + private val subq = d3.select(sum(Symbol("d3_fk1")).as(Symbol("col"))) test("Test 1: Selective star-join on all dimensions") { // Star join: @@ -362,7 +363,7 @@ class StarJoinReorderSuite extends JoinReorderPlanTestBase with StatsEstimationT (nameToAttr("f1_fk3") === "col".attr)) val expected = - 
d3.select('d3_fk1).select(sum('d3_fk1).as('col)) + d3.select(Symbol("d3_fk1")).select(sum(Symbol("d3_fk1")).as(Symbol("col"))) .join(f1, Inner, Some(nameToAttr("f1_fk3") === "col".attr)) .join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1"))) .join(d2.where(nameToAttr("d2_c2") === 2), Inner, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala index 99051d692451..1345e1017f0b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -181,9 +181,9 @@ class ErrorParserSuite extends AnalysisTest { |ORDER BY c """.stripMargin, table("t") - .where('a - 'b > 10) - .groupBy('fake - 'breaker)('a, 'b) - .orderBy('c.asc)) + .where(Symbol("a") - Symbol("b") > 10) + .groupBy(Symbol("fake") - Symbol("breaker"))(Symbol("a"), Symbol("b")) + .orderBy(Symbol("c").asc)) intercept( """ |SELECT * FROM tab diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 0b304a799cdc..6654ace49865 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -76,17 +76,17 @@ class ExpressionParserSuite extends AnalysisTest { // NamedExpression (Alias/Multialias) test("named expressions") { // No Alias - val r0 = 'a + val r0 = Symbol("a") assertEqual("a", r0) // Single Alias. - val r1 = 'a as "b" + val r1 = Symbol("a") as "b" assertEqual("a as b", r1) assertEqual("a b", r1) // Multi-Alias - assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c"))) - assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c"))) + assertEqual("a as (b, c)", MultiAlias(Symbol("a"), Seq("b", "c"))) + assertEqual("a() (b, c)", MultiAlias(Symbol("a").function(), Seq("b", "c"))) // Numeric literals without a space between the literal qualifier and the alias, should not be // interpreted as such. An unresolved reference should be returned instead. @@ -94,23 +94,25 @@ class ExpressionParserSuite extends AnalysisTest { assertEqual("1SL", Symbol("1SL")) // Aliased star is allowed. 
- assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b) + assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as Symbol("b")) } test("binary logical expressions") { // And - assertEqual("a and b", 'a && 'b) + assertEqual("a and b", Symbol("a") && Symbol("b")) // Or - assertEqual("a or b", 'a || 'b) + assertEqual("a or b", Symbol("a") || Symbol("b")) // Combination And/Or check precedence - assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd)) - assertEqual("a or b or c and d", 'a || 'b || ('c && 'd)) + assertEqual("a and b or c and d", (Symbol("a") && Symbol("b")) || (Symbol("c") && Symbol("d"))) + assertEqual("a or b or c and d", Symbol("a") || Symbol("b") || (Symbol("c") && Symbol("d"))) // Multiple AND/OR get converted into a balanced tree - assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f)) - assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f)) + assertEqual("a or b or c or d or e or f", ((Symbol("a") || Symbol("b")) || Symbol("c")) || + ((Symbol("d") || Symbol("e")) || Symbol("f"))) + assertEqual("a and b and c and d and e and f", ((Symbol("a") && Symbol("b")) && + Symbol("c")) && ((Symbol("d") && Symbol("e")) && Symbol("f"))) } test("long binary logical expressions") { @@ -125,8 +127,8 @@ class ExpressionParserSuite extends AnalysisTest { } test("not expressions") { - assertEqual("not a", !'a) - assertEqual("!a", !'a) + assertEqual("not a", !Symbol("a")) + assertEqual("!a", !Symbol("a")) assertEqual("not true > true", Not(GreaterThan(true, true))) } @@ -137,64 +139,66 @@ class ExpressionParserSuite extends AnalysisTest { } test("comparison expressions") { - assertEqual("a = b", 'a === 'b) - assertEqual("a == b", 'a === 'b) - assertEqual("a <=> b", 'a <=> 'b) - assertEqual("a <> b", 'a =!= 'b) - assertEqual("a != b", 'a =!= 'b) - assertEqual("a < b", 'a < 'b) - assertEqual("a <= b", 'a <= 'b) - assertEqual("a !> b", 'a <= 'b) - assertEqual("a > b", 'a > 'b) - assertEqual("a >= b", 'a >= 'b) - assertEqual("a !< b", 'a >= 'b) + assertEqual("a = b", Symbol("a") === Symbol("b")) + assertEqual("a == b", Symbol("a") === Symbol("b")) + assertEqual("a <=> b", Symbol("a") <=> Symbol("b")) + assertEqual("a <> b", Symbol("a") =!= Symbol("b")) + assertEqual("a != b", Symbol("a") =!= Symbol("b")) + assertEqual("a < b", Symbol("a") < Symbol("b")) + assertEqual("a <= b", Symbol("a") <= Symbol("b")) + assertEqual("a !> b", Symbol("a") <= Symbol("b")) + assertEqual("a > b", Symbol("a") > Symbol("b")) + assertEqual("a >= b", Symbol("a") >= Symbol("b")) + assertEqual("a !< b", Symbol("a") >= Symbol("b")) } test("between expressions") { - assertEqual("a between b and c", 'a >= 'b && 'a <= 'c) - assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c)) + assertEqual("a between b and c", Symbol("a") >= Symbol("b") && Symbol("a") <= Symbol("c")) + assertEqual("a not between b and c", + !(Symbol("a") >= Symbol("b") && Symbol("a") <= Symbol("c"))) } test("in expressions") { - assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd)) - assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd))) + assertEqual("a in (b, c, d)", Symbol("a") in (Symbol("b"), Symbol("c"), Symbol("d"))) + assertEqual("a not in (b, c, d)", !(Symbol("a") in (Symbol("b"), Symbol("c"), Symbol("d")))) } test("in sub-query") { assertEqual( "a in (select b from c)", - InSubquery(Seq('a), ListQuery(table("c").select('b)))) + InSubquery(Seq(Symbol("a")), ListQuery(table("c").select(Symbol("b"))))) assertEqual( "(a, b, c) in (select d, e, f from g)", - 
InSubquery(Seq('a, 'b, 'c), ListQuery(table("g").select('d, 'e, 'f)))) + InSubquery(Seq(Symbol("a"), Symbol("b"), Symbol("c")), + ListQuery(table("g").select(Symbol("d"), Symbol("e"), Symbol("f"))))) assertEqual( "(a, b) in (select c from d)", - InSubquery(Seq('a, 'b), ListQuery(table("d").select('c)))) + InSubquery(Seq(Symbol("a"), Symbol("b")), ListQuery(table("d").select(Symbol("c"))))) assertEqual( "(a) in (select b from c)", - InSubquery(Seq('a), ListQuery(table("c").select('b)))) + InSubquery(Seq(Symbol("a")), ListQuery(table("c").select(Symbol("b"))))) } test("like expressions") { - assertEqual("a like 'pattern%'", 'a like "pattern%") - assertEqual("a not like 'pattern%'", !('a like "pattern%")) - assertEqual("a rlike 'pattern%'", 'a rlike "pattern%") - assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%")) - assertEqual("a regexp 'pattern%'", 'a rlike "pattern%") - assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%")) + assertEqual("a like 'pattern%'", Symbol("a") like "pattern%") + assertEqual("a not like 'pattern%'", !(Symbol("a") like "pattern%")) + assertEqual("a rlike 'pattern%'", Symbol("a") rlike "pattern%") + assertEqual("a not rlike 'pattern%'", !(Symbol("a") rlike "pattern%")) + assertEqual("a regexp 'pattern%'", Symbol("a") rlike "pattern%") + assertEqual("a not regexp 'pattern%'", !(Symbol("a") rlike "pattern%")) } test("like escape expressions") { val message = "Escape string must contain only one character." - assertEqual("a like 'pattern%' escape '#'", 'a.like("pattern%", '#')) - assertEqual("a like 'pattern%' escape '\"'", 'a.like("pattern%", '\"')) + assertEqual("a like 'pattern%' escape '#'", Symbol("a").like("pattern%", '#')) + assertEqual("a like 'pattern%' escape '\"'", Symbol("a").like("pattern%", '\"')) intercept("a like 'pattern%' escape '##'", message) intercept("a like 'pattern%' escape ''", message) - assertEqual("a not like 'pattern%' escape '#'", !('a.like("pattern%", '#'))) - assertEqual("a not like 'pattern%' escape '\"'", !('a.like("pattern%", '\"'))) + assertEqual("a not like 'pattern%' escape '#'", !(Symbol("a").like("pattern%", '#'))) + assertEqual("a not like 'pattern%' escape '\"'", !(Symbol("a").like("pattern%", '\"'))) intercept("a not like 'pattern%' escape '\"/'", message) intercept("a not like 'pattern%' escape ''", message) } @@ -202,21 +206,22 @@ class ExpressionParserSuite extends AnalysisTest { test("like expressions with ESCAPED_STRING_LITERALS = true") { withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") { val parser = new CatalystSqlParser() - assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser) - assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser) - assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser) + assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", + Symbol("a") rlike "^\\x20[\\x20-\\x23]+$", parser) + assertEqual("a rlike 'pattern\\\\'", Symbol("a") rlike "pattern\\\\", parser) + assertEqual("a rlike 'pattern\\t\\n'", Symbol("a") rlike "pattern\\t\\n", parser) } } test("(NOT) LIKE (ANY | SOME | ALL) expressions") { Seq("any", "some").foreach { quantifier => - assertEqual(s"a like $quantifier ('foo%', 'b%')", 'a likeAny("foo%", "b%")) - assertEqual(s"a not like $quantifier ('foo%', 'b%')", 'a notLikeAny("foo%", "b%")) - assertEqual(s"not (a like $quantifier ('foo%', 'b%'))", !('a likeAny("foo%", "b%"))) + assertEqual(s"a like $quantifier ('foo%', 'b%')", Symbol("a") likeAny("foo%", "b%")) + assertEqual(s"a not like $quantifier ('foo%', 
'b%')", Symbol("a") notLikeAny("foo%", "b%")) + assertEqual(s"not (a like $quantifier ('foo%', 'b%'))", !(Symbol("a") likeAny("foo%", "b%"))) } - assertEqual("a like all ('foo%', 'b%')", 'a likeAll("foo%", "b%")) - assertEqual("a not like all ('foo%', 'b%')", 'a notLikeAll("foo%", "b%")) - assertEqual("not (a like all ('foo%', 'b%'))", !('a likeAll("foo%", "b%"))) + assertEqual("a like all ('foo%', 'b%')", Symbol("a") likeAll("foo%", "b%")) + assertEqual("a not like all ('foo%', 'b%')", Symbol("a") notLikeAll("foo%", "b%")) + assertEqual("not (a like all ('foo%', 'b%'))", !(Symbol("a") likeAll("foo%", "b%"))) Seq("any", "some", "all").foreach { quantifier => intercept(s"a like $quantifier()", "Expected something between '(' and ')'") @@ -224,73 +229,76 @@ class ExpressionParserSuite extends AnalysisTest { } test("is null expressions") { - assertEqual("a is null", 'a.isNull) - assertEqual("a is not null", 'a.isNotNull) - assertEqual("a = b is null", ('a === 'b).isNull) - assertEqual("a = b is not null", ('a === 'b).isNotNull) + assertEqual("a is null", Symbol("a").isNull) + assertEqual("a is not null", Symbol("a").isNotNull) + assertEqual("a = b is null", (Symbol("a") === Symbol("b")).isNull) + assertEqual("a = b is not null", (Symbol("a") === Symbol("b")).isNotNull) } test("is distinct expressions") { - assertEqual("a is distinct from b", !('a <=> 'b)) - assertEqual("a is not distinct from b", 'a <=> 'b) + assertEqual("a is distinct from b", !(Symbol("a") <=> Symbol("b"))) + assertEqual("a is not distinct from b", Symbol("a") <=> Symbol("b")) } test("binary arithmetic expressions") { // Simple operations - assertEqual("a * b", 'a * 'b) - assertEqual("a / b", 'a / 'b) - assertEqual("a DIV b", 'a div 'b) - assertEqual("a % b", 'a % 'b) - assertEqual("a + b", 'a + 'b) - assertEqual("a - b", 'a - 'b) - assertEqual("a & b", 'a & 'b) - assertEqual("a ^ b", 'a ^ 'b) - assertEqual("a | b", 'a | 'b) + assertEqual("a * b", Symbol("a") * Symbol("b")) + assertEqual("a / b", Symbol("a") / Symbol("b")) + assertEqual("a DIV b", Symbol("a") div Symbol("b")) + assertEqual("a % b", Symbol("a") % Symbol("b")) + assertEqual("a + b", Symbol("a") + Symbol("b")) + assertEqual("a - b", Symbol("a") - Symbol("b")) + assertEqual("a & b", Symbol("a") & Symbol("b")) + assertEqual("a ^ b", Symbol("a") ^ Symbol("b")) + assertEqual("a | b", Symbol("a") | Symbol("b")) // Check precedences assertEqual( "a * t | b ^ c & d - e + f % g DIV h / i * k", - 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g div 'h) / 'i * 'k))))) + Symbol("a") * Symbol("t") | (Symbol("b") ^ (Symbol("c") & (Symbol("d") - Symbol("e") + + ((Symbol("f") % Symbol("g") div Symbol("h")) / Symbol("i") * Symbol("k")))))) } test("unary arithmetic expressions") { - assertEqual("+a", +'a) - assertEqual("-a", -'a) - assertEqual("~a", ~'a) - assertEqual("-+~~a", -( +(~(~'a)))) + assertEqual("+a", +Symbol("a")) + assertEqual("-a", -Symbol("a")) + assertEqual("~a", ~Symbol("a")) + assertEqual("-+~~a", -( +(~(~Symbol("a"))))) } test("cast expressions") { // Note that DataType parsing is tested elsewhere. 
- assertEqual("cast(a as int)", 'a.cast(IntegerType)) - assertEqual("cast(a as timestamp)", 'a.cast(TimestampType)) - assertEqual("cast(a as array)", 'a.cast(ArrayType(IntegerType))) - assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType)) + assertEqual("cast(a as int)", Symbol("a").cast(IntegerType)) + assertEqual("cast(a as timestamp)", Symbol("a").cast(TimestampType)) + assertEqual("cast(a as array)", Symbol("a").cast(ArrayType(IntegerType))) + assertEqual("cast(cast(a as int) as long)", Symbol("a").cast(IntegerType).cast(LongType)) } test("function expressions") { - assertEqual("foo()", 'foo.function()) + assertEqual("foo()", Symbol("foo").function()) assertEqual("foo.bar()", UnresolvedFunction(FunctionIdentifier("bar", Some("foo")), Seq.empty, isDistinct = false)) - assertEqual("foo(*)", 'foo.function(star())) - assertEqual("count(*)", 'count.function(1)) - assertEqual("foo(a, b)", 'foo.function('a, 'b)) - assertEqual("foo(all a, b)", 'foo.function('a, 'b)) - assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) - assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) - assertEqual("`select`(all a, b)", 'select.function('a, 'b)) + assertEqual("foo(*)", Symbol("foo").function(star())) + assertEqual("count(*)", Symbol("count").function(1)) + assertEqual("foo(a, b)", Symbol("foo").function(Symbol("a"), Symbol("b"))) + assertEqual("foo(all a, b)", Symbol("foo").function(Symbol("a"), Symbol("b"))) + assertEqual("foo(distinct a, b)", Symbol("foo").distinctFunction(Symbol("a"), Symbol("b"))) + assertEqual("grouping(distinct a, b)", + Symbol("grouping").distinctFunction(Symbol("a"), Symbol("b"))) + assertEqual("`select`(all a, b)", Symbol("select").function(Symbol("a"), Symbol("b"))) intercept("foo(a x)", "extraneous input 'x'") } private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name)) test("lambda functions") { - assertEqual("x -> x + 1", LambdaFunction(lv('x) + 1, Seq(lv('x)))) - assertEqual("(x, y) -> x + y", LambdaFunction(lv('x) + lv('y), Seq(lv('x), lv('y)))) + assertEqual("x -> x + 1", LambdaFunction(lv(Symbol("x")) + 1, Seq(lv(Symbol("x"))))) + assertEqual("(x, y) -> x + y", + LambdaFunction(lv(Symbol("x")) + lv(Symbol("y")), Seq(lv(Symbol("x")), lv(Symbol("y"))))) } test("window function expressions") { - val func = 'foo.function(star()) + val func = Symbol("foo").function(star()) def windowed( partitioning: Seq[Expression] = Seq.empty, ordering: Seq[SortOrder] = Seq.empty, @@ -301,27 +309,32 @@ class ExpressionParserSuite extends AnalysisTest { // Basic window testing. 
assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1"))) assertEqual("foo(*) over ()", windowed()) - assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc))) - assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc))) - assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) - assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) + assertEqual("foo(*) over (partition by a, b)", windowed(Seq(Symbol("a"), Symbol("b")))) + assertEqual("foo(*) over (distribute by a, b)", windowed(Seq(Symbol("a"), Symbol("b")))) + assertEqual("foo(*) over (cluster by a, b)", windowed(Seq(Symbol("a"), Symbol("b")))) + assertEqual("foo(*) over (order by a desc, b asc)", + windowed(Seq.empty, Seq(Symbol("a").desc, Symbol("b").asc))) + assertEqual("foo(*) over (sort by a desc, b asc)", + windowed(Seq.empty, Seq(Symbol("a").desc, Symbol("b").asc))) + assertEqual("foo(*) over (partition by a, b order by c)", + windowed(Seq(Symbol("a"), Symbol("b")), Seq(Symbol("c").asc))) + assertEqual("foo(*) over (distribute by a, b sort by c)", + windowed(Seq(Symbol("a"), Symbol("b")), Seq(Symbol("c").asc))) // Test use of expressions in window functions. assertEqual( "sum(product + 1) over (partition by ((product) + (1)) order by 2)", - WindowExpression('sum.function('product + 1), - WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + WindowExpression(Symbol("sum").function(Symbol("product") + 1), + WindowSpecDefinition(Seq(Symbol("product") + 1), Seq(Literal(2).asc), UnspecifiedFrame))) assertEqual( "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", - WindowExpression('sum.function('product + 1), - WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + WindowExpression(Symbol("sum").function(Symbol("product") + 1), + WindowSpecDefinition( + Seq(Symbol("product") / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) } test("range/rows window function expressions") { - val func = 'foo.function(star()) + val func = Symbol("foo").function(star()) def windowed( partitioning: Seq[Expression] = Seq.empty, ordering: Seq[SortOrder] = Seq.empty, @@ -380,7 +393,8 @@ class ExpressionParserSuite extends AnalysisTest { boundaries.foreach { case (boundarySql, begin, end) => val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)" - val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end)) + val expr = windowed(Seq(Symbol("a")), Seq(Symbol("b").asc), + SpecifiedWindowFrame(frameType, begin, end)) assertEqual(query, expr) } } @@ -392,65 +406,70 @@ class ExpressionParserSuite extends AnalysisTest { test("row constructor") { // Note that '(a)' will be interpreted as a nested expression. 
- assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) - assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) - assertEqual("(a as b, b as c)", CreateStruct(Seq('a as 'b, 'b as 'c))) + assertEqual("(a, b)", CreateStruct(Seq(Symbol("a"), Symbol("b")))) + assertEqual("(a, b, c)", CreateStruct(Seq(Symbol("a"), Symbol("b"), Symbol("c")))) + assertEqual("(a as b, b as c)", + CreateStruct(Seq(Symbol("a") as Symbol("b"), Symbol("b") as Symbol("c")))) } test("scalar sub-query") { assertEqual( "(select max(val) from tbl) > current", - ScalarSubquery(table("tbl").select('max.function('val))) > 'current) + ScalarSubquery( + table("tbl").select(Symbol("max").function(Symbol("val")))) > Symbol("current")) assertEqual( "a = (select b from s)", - 'a === ScalarSubquery(table("s").select('b))) + Symbol("a") === ScalarSubquery(table("s").select(Symbol("b")))) } test("case when") { assertEqual("case a when 1 then b when 2 then c else d end", - CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) + CaseKeyWhen(Symbol("a"), Seq(1, Symbol("b"), 2, Symbol("c"), Symbol("d")))) assertEqual("case (a or b) when true then c when false then d else e end", - CaseKeyWhen('a || 'b, Seq(true, 'c, false, 'd, 'e))) + CaseKeyWhen(Symbol("a") || Symbol("b"), + Seq(true, Symbol("c"), false, Symbol("d"), Symbol("e")))) assertEqual("case 'a'='a' when true then 1 end", CaseKeyWhen("a" === "a", Seq(true, 1))) assertEqual("case when a = 1 then b when a = 2 then c else d end", - CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) + CaseWhen(Seq((Symbol("a") === 1, Symbol("b").expr), + (Symbol("a") === 2, Symbol("c").expr)), Symbol("d"))) assertEqual("case when (1) + case when a > b then c else d end then f else g end", - CaseWhen(Seq((Literal(1) + CaseWhen(Seq(('a > 'b, 'c.expr)), 'd.expr), 'f.expr)), 'g)) + CaseWhen(Seq((Literal(1) + CaseWhen(Seq((Symbol("a") > Symbol("b"), + Symbol("c").expr)), Symbol("d").expr), Symbol("f").expr)), Symbol("g"))) } test("dereference") { assertEqual("a.b", UnresolvedAttribute("a.b")) assertEqual("`select`.b", UnresolvedAttribute("select.b")) - assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. + assertEqual("(a + b).b", (Symbol("a") + Symbol("b")).getField("b")) // This will fail analysis. assertEqual( "struct(a, b).b", - namedStruct(NamePlaceholder, 'a, NamePlaceholder, 'b).getField("b")) + namedStruct(NamePlaceholder, Symbol("a"), NamePlaceholder, Symbol("b")).getField("b")) } test("reference") { // Regular - assertEqual("a", 'a) + assertEqual("a", Symbol("a")) // Starting with a digit. assertEqual("1a", Symbol("1a")) // Quoted using a keyword. - assertEqual("`select`", 'select) + assertEqual("`select`", Symbol("select")) // Unquoted using an unreserved keyword. 
- assertEqual("columns", 'columns) + assertEqual("columns", Symbol("columns")) } test("subscript") { - assertEqual("a[b]", 'a.getItem('b)) - assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1)) - assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b)) + assertEqual("a[b]", Symbol("a").getItem(Symbol("b"))) + assertEqual("a[1 + 1]", Symbol("a").getItem(Literal(1) + 1)) + assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem(Symbol("b"))) } test("parenthesis") { - assertEqual("(a)", 'a) - assertEqual("r * (a + b)", 'r * ('a + 'b)) + assertEqual("(a)", Symbol("a")) + assertEqual("r * (a + b)", Symbol("r") * (Symbol("a") + Symbol("b"))) } test("type constructors") { @@ -759,7 +778,8 @@ class ExpressionParserSuite extends AnalysisTest { test("composed expressions") { assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q")) - assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) + assertEqual("1 - f('o', o(bar))", + Literal(1) - Symbol("f").function("o", Symbol("o").function(Symbol("bar")))) intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") } @@ -786,10 +806,10 @@ class ExpressionParserSuite extends AnalysisTest { } test("SPARK-19526 Support ignore nulls keywords for first and last") { - assertEqual("first(a ignore nulls)", First('a, true).toAggregateExpression()) - assertEqual("first(a)", First('a, false).toAggregateExpression()) - assertEqual("last(a ignore nulls)", Last('a, true).toAggregateExpression()) - assertEqual("last(a)", Last('a, false).toAggregateExpression()) + assertEqual("first(a ignore nulls)", First(Symbol("a"), true).toAggregateExpression()) + assertEqual("first(a)", First(Symbol("a"), false).toAggregateExpression()) + assertEqual("last(a ignore nulls)", Last(Symbol("a"), true).toAggregateExpression()) + assertEqual("last(a)", Last(Symbol("a"), false).toAggregateExpression()) } test("timestamp literals") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 1ca2f5226903..1ca888df3e9e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -185,7 +185,8 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select * from a minus distinct select * from b", a.except(b, isAll = false)) assertEqual("select * from a " + "intersect select * from b", a.intersect(b, isAll = false)) - assertEqual("select * from a intersect distinct select * from b", a.intersect(b, isAll = false)) + assertEqual("select * from a intersect distinct select * from b", + a.intersect(b, isAll = false)) assertEqual("select * from a intersect all select * from b", a.intersect(b, isAll = true)) } @@ -208,23 +209,26 @@ class PlanParserSuite extends AnalysisTest { test("simple select query") { assertEqual("select 1", OneRowRelation().select(1)) - assertEqual("select a, b", OneRowRelation().select('a, 'b)) - assertEqual("select a, b from db.c", table("db", "c").select('a, 'b)) - assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) + assertEqual("select a, b", OneRowRelation().select(Symbol("a"), Symbol("b"))) + assertEqual("select a, b from db.c", table("db", "c").select(Symbol("a"), Symbol("b"))) + assertEqual("select a, b from db.c where x < 1", + table("db", "c").where(Symbol("x") < 1).select(Symbol("a"), 
Symbol("b"))) assertEqual( "select a, b from db.c having x < 1", - table("db", "c").having()('a, 'b)('x < 1)) - assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) - assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) - assertEqual("select from tbl", OneRowRelation().select('from.as("tbl"))) - assertEqual("select a from 1k.2m", table("1k", "2m").select('a)) + table("db", "c").having()(Symbol("a"), Symbol("b"))(Symbol("x") < 1)) + assertEqual("select distinct a, b from db.c", + Distinct(table("db", "c").select(Symbol("a"), Symbol("b")))) + assertEqual("select all a, b from db.c", table("db", "c").select(Symbol("a"), Symbol("b"))) + assertEqual("select from tbl", OneRowRelation().select(Symbol("from").as("tbl"))) + assertEqual("select a from 1k.2m", table("1k", "2m").select(Symbol("a"))) } test("hive-style single-FROM statement") { - assertEqual("from a select b, c", table("a").select('b, 'c)) - assertEqual( - "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c)) - assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c))) + assertEqual("from a select b, c", table("a").select(Symbol("b"), Symbol("c"))) + assertEqual("from db.a select b, c where d < 1", + table("db", "a").where(Symbol("d") < 1).select(Symbol("b"), Symbol("c"))) + assertEqual("from a select distinct b, c", + Distinct(table("a").select(Symbol("b"), Symbol("c")))) // Weird "FROM table" queries, should be invalid anyway intercept("from a", "no viable alternative at input 'from a'") @@ -234,7 +238,7 @@ class PlanParserSuite extends AnalysisTest { test("multi select query") { assertEqual( "from a select * select * where s < 10", - table("a").select(star()).union(table("a").where('s < 10).select(star()))) + table("a").select(star()).union(table("a").where(Symbol("s") < 10).select(star()))) intercept( "from a select * select * from x where a.s < 10", "mismatched input 'from' expecting") @@ -244,7 +248,7 @@ class PlanParserSuite extends AnalysisTest { assertEqual( "from a insert into tbl1 select * insert into tbl2 select * where s < 10", table("a").select(star()).insertInto("tbl1").union( - table("a").where('s < 10).select(star()).insertInto("tbl2"))) + table("a").where(Symbol("s") < 10).select(star()).insertInto("tbl2"))) assertEqual( "select * from (from a select * select *)", table("a").select(star()) @@ -267,8 +271,8 @@ class PlanParserSuite extends AnalysisTest { val orderSortDistrClusterClauses = Seq( ("", basePlan), - (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), - (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)) + (" order by a, b desc", basePlan.orderBy(Symbol("a").asc, Symbol("b").desc)), + (" sort by a, b desc", basePlan.sortBy(Symbol("a").asc, Symbol("b").desc)) ) orderSortDistrClusterClauses.foreach { @@ -308,7 +312,7 @@ class PlanParserSuite extends AnalysisTest { insert(Map("c" -> Option("d"), "e" -> Option("1")))) // Multi insert - val plan2 = table("t").where('x > 5).select(star()) + val plan2 = table("t").where(Symbol("x") > 5).select(star()) assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", plan.limit(1).insertInto("s").union(plan2.insertInto("u"))) } @@ -317,20 +321,24 @@ class PlanParserSuite extends AnalysisTest { val sql = "select a, b, sum(c) as c from d group by a, b" // Normal - assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c"))) + assertEqual(sql, table("d").groupBy(Symbol("a"), + Symbol("b"))(Symbol("a"), 
Symbol("b"), Symbol("sum").function(Symbol("c")).as("c"))) // Cube assertEqual(s"$sql with cube", - table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + table("d").groupBy(Cube(Seq(Symbol("a"), + Symbol("b"))))(Symbol("a"), Symbol("b"), Symbol("sum").function(Symbol("c")).as("c"))) // Rollup assertEqual(s"$sql with rollup", - table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + table("d").groupBy(Rollup(Seq(Symbol("a"), + Symbol("b"))))(Symbol("a"), Symbol("b"), Symbol("sum").function(Symbol("c")).as("c"))) // Grouping Sets assertEqual(s"$sql grouping sets((a, b), (a), ())", - GroupingSets(Seq(Seq('a, 'b), Seq('a), Seq()), Seq('a, 'b), table("d"), - Seq('a, 'b, 'sum.function('c).as("c")))) + GroupingSets(Seq(Seq(Symbol("a"), Symbol("b")), + Seq(Symbol("a")), Seq()), Seq(Symbol("a"), Symbol("b")), table("d"), + Seq(Symbol("a"), Symbol("b"), Symbol("sum").function(Symbol("c")).as("c")))) val m = intercept[ParseException] { parsePlan("SELECT a, b, count(distinct a, distinct b) as c FROM d GROUP BY a, b") @@ -350,7 +358,7 @@ class PlanParserSuite extends AnalysisTest { // Note that WindowSpecs are testing in the ExpressionParserSuite val sql = "select * from t" val plan = table("t").select(star()) - val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), + val spec = WindowSpecDefinition(Seq(Symbol("a"), Symbol("b")), Seq(Symbol("c").asc), SpecifiedWindowFrame(RowFrame, -Literal(1), Literal(1))) // Test window resolution. @@ -376,8 +384,9 @@ class PlanParserSuite extends AnalysisTest { } test("lateral view") { - val explode = UnresolvedGenerator(FunctionIdentifier("explode"), Seq('x)) - val jsonTuple = UnresolvedGenerator(FunctionIdentifier("json_tuple"), Seq('x, 'y)) + val explode = UnresolvedGenerator(FunctionIdentifier("explode"), Seq(Symbol("x"))) + val jsonTuple = + UnresolvedGenerator(FunctionIdentifier("json_tuple"), Seq(Symbol("x"), Symbol("y"))) // Single lateral view assertEqual( @@ -413,12 +422,12 @@ class PlanParserSuite extends AnalysisTest { .generate(jsonTuple, alias = Some("jtup"), outputNames = Seq("q", "z")) .select(star()) .insertInto("t2"), - from.where('s < 10).select(star()).insertInto("t3"))) + from.where(Symbol("s") < 10).select(star()).insertInto("t3"))) // Unresolved generator. 
val expected = table("t") .generate( - UnresolvedGenerator(FunctionIdentifier("posexplode"), Seq('x)), + UnresolvedGenerator(FunctionIdentifier("posexplode"), Seq(Symbol("x"))), alias = Some("posexpl"), outputNames = Seq("x", "y")) .select(star()) @@ -447,7 +456,8 @@ class PlanParserSuite extends AnalysisTest { val testConditionalJoin = (sql: String, jt: JoinType) => { assertEqual( s"select * from t $sql u as uu on a = b", - table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star())) + table("t").join(table("u").as("uu"), jt, + Option(Symbol("a") === Symbol("b"))).select(star())) } val testNaturalJoin = (sql: String, jt: JoinType) => { assertEqual( @@ -507,17 +517,20 @@ class PlanParserSuite extends AnalysisTest { "select * from t1 inner join (t2 inner join t3 on col3 = col2) on col3 = col1", table("t1") .join(table("t2") - .join(table("t3"), Inner, Option('col3 === 'col2)), Inner, Option('col3 === 'col1)) + .join(table("t3"), Inner, Option(Symbol("col3") === Symbol("col2"))), + Inner, Option(Symbol("col3") === Symbol("col1"))) .select(star())) assertEqual( "select * from t1 inner join (t2 inner join t3) on col3 = col2", table("t1") - .join(table("t2").join(table("t3"), Inner, None), Inner, Option('col3 === 'col2)) + .join(table("t2").join(table("t3"), Inner, None), + Inner, Option(Symbol("col3") === Symbol("col2"))) .select(star())) assertEqual( "select * from t1 inner join (t2 inner join t3 on col3 = col2)", table("t1") - .join(table("t2").join(table("t3"), Inner, Option('col3 === 'col2)), Inner, None) + .join(table("t2").join(table("t3"), + Inner, Option(Symbol("col3") === Symbol("col2"))), Inner, None) .select(star())) // Implicit joins. @@ -548,7 +561,7 @@ class PlanParserSuite extends AnalysisTest { } test("sub-query") { - val plan = table("t0").select('id) + val plan = table("t0").select(Symbol("id")) assertEqual("select id from (t0)", plan) assertEqual("select id from ((((((t0))))))", plan) assertEqual( @@ -566,20 +579,23 @@ class PlanParserSuite extends AnalysisTest { | union all | (select id from t0)) as u_1 """.stripMargin, - plan.union(plan).union(plan).as("u_1").select('id)) + plan.union(plan).union(plan).as("u_1").select(Symbol("id"))) } test("scalar sub-query") { assertEqual( "select (select max(b) from s) ss from t", - table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss"))) + table("t") + .select(ScalarSubquery(table("s").select(Symbol("max").function(Symbol("b")))).as("ss"))) assertEqual( "select * from t where a = (select b from s)", - table("t").where('a === ScalarSubquery(table("s").select('b))).select(star())) + table("t") + .where(Symbol("a") === ScalarSubquery(table("s").select(Symbol("b")))).select(star())) assertEqual( "select g from t group by g having a > (select b from s)", table("t") - .having('g)('g)('a > ScalarSubquery(table("s").select('b)))) + .having( + Symbol("g"))(Symbol("g"))(Symbol("a") > ScalarSubquery(table("s").select(Symbol("b"))))) } test("table reference") { @@ -623,7 +639,7 @@ class PlanParserSuite extends AnalysisTest { "t", UnresolvedSubqueryColumnAliases( Seq("col1", "col2"), - UnresolvedRelation(TableIdentifier("t")).select('a.as("x"), 'b.as("y")) + UnresolvedRelation(TableIdentifier("t")).select(Symbol("a").as("x"), Symbol("b").as("y")) ) ).select(star())) } @@ -649,7 +665,7 @@ class PlanParserSuite extends AnalysisTest { "t", UnresolvedSubqueryColumnAliases( Seq("col1", "col2"), - UnresolvedRelation(TableIdentifier("t")).select('a.as("x"), 'b.as("y"))) + 
UnresolvedRelation(TableIdentifier("t")).select(Symbol("a").as("x"), Symbol("b").as("y"))) ).select($"t.col1", $"t.col2") ) } @@ -668,10 +684,10 @@ class PlanParserSuite extends AnalysisTest { test("simple select query with !> and !<") { // !< is equivalent to >= assertEqual("select a, b from db.c where x !< 1", - table("db", "c").where('x >= 1).select('a, 'b)) + table("db", "c").where(Symbol("x") >= 1).select(Symbol("a"), Symbol("b"))) // !> is equivalent to <= assertEqual("select a, b from db.c where x !> 1", - table("db", "c").where('x <= 1).select('a, 'b)) + table("db", "c").where(Symbol("x") <= 1).select(Symbol("a"), Symbol("b"))) } test("select hint syntax") { @@ -715,7 +731,7 @@ class PlanParserSuite extends AnalysisTest { comparePlans( parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"), UnresolvedHint("MAPJOIN", Seq($"t"), - table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc)) + table("t").where(Literal(true)).groupBy(Symbol("a"))(Symbol("a"))).orderBy(Symbol("a").asc)) comparePlans( parsePlan("SELECT /*+ COALESCE(10) */ * FROM t"), @@ -1039,14 +1055,14 @@ class PlanParserSuite extends AnalysisTest { test("CTE with column alias") { assertEqual( "WITH t(x) AS (SELECT c FROM a) SELECT * FROM t", - cte(table("t").select(star()), "t" -> ((table("a").select('c), Seq("x"))))) + cte(table("t").select(star()), "t" -> ((table("a").select(Symbol("c")), Seq("x"))))) } test("statement containing terminal semicolons") { assertEqual("select 1;", OneRowRelation().select(1)) - assertEqual("select a, b;", OneRowRelation().select('a, 'b)) - assertEqual("select a, b from db.c;;;", table("db", "c").select('a, 'b)) - assertEqual("select a, b from db.c; ;; ;", table("db", "c").select('a, 'b)) + assertEqual("select a, b;", OneRowRelation().select(Symbol("a"), Symbol("b"))) + assertEqual("select a, b from db.c;;;", table("db", "c").select(Symbol("a"), Symbol("b"))) + assertEqual("select a, b from db.c; ;; ;", table("db", "c").select(Symbol("a"), Symbol("b"))) } test("SPARK-32106: TRANSFORM plan") { @@ -1058,7 +1074,7 @@ class PlanParserSuite extends AnalysisTest { |FROM testData """.stripMargin, ScriptTransformation( - Seq('a, 'b, 'c), + Seq(Symbol("a"), Symbol("b"), Symbol("c")), "cat", Seq(AttributeReference("key", StringType)(), AttributeReference("value", StringType)()), @@ -1075,7 +1091,7 @@ class PlanParserSuite extends AnalysisTest { |FROM testData """.stripMargin, ScriptTransformation( - Seq('a, 'b, 'c), + Seq(Symbol("a"), Symbol("b"), Symbol("c")), "cat", Seq(AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), @@ -1092,7 +1108,7 @@ class PlanParserSuite extends AnalysisTest { |FROM testData """.stripMargin, ScriptTransformation( - Seq('a, 'b, 'c), + Seq(Symbol("a"), Symbol("b"), Symbol("c")), "cat", Seq(AttributeReference("a", IntegerType)(), AttributeReference("b", StringType)(), @@ -1121,7 +1137,7 @@ class PlanParserSuite extends AnalysisTest { |FROM testData """.stripMargin, ScriptTransformation( - Seq('a, 'b, 'c), + Seq(Symbol("a"), Symbol("b"), Symbol("c")), "cat", Seq(AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 5ad748b6113d..360a15901112 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -57,22 +57,23 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { } test("propagating constraints in filters") { - val tr = LocalRelation('a.int, 'b.string, 'c.int) + val tr = LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").int) assert(tr.analyze.constraints.isEmpty) - assert(tr.where('a.attr > 10).select('c.attr, 'b.attr).analyze.constraints.isEmpty) + assert(tr.where(Symbol("a").attr > 10).select(Symbol("c").attr, + Symbol("b").attr).analyze.constraints.isEmpty) verifyConstraints(tr - .where('a.attr > 10) + .where(Symbol("a").attr > 10) .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr, "a") > 10, IsNotNull(resolveColumn(tr, "a"))))) verifyConstraints(tr - .where('a.attr > 10) - .select('c.attr, 'a.attr) - .where('c.attr =!= 100) + .where(Symbol("a").attr > 10) + .select(Symbol("c").attr, Symbol("a").attr) + .where(Symbol("c").attr =!= 100) .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr, "a") > 10, resolveColumn(tr, "c") =!= 100, @@ -81,12 +82,14 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { } test("propagating constraints in aggregate") { - val tr = LocalRelation('a.int, 'b.string, 'c.int) + val tr = LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").int) assert(tr.analyze.constraints.isEmpty) - val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) - .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3).analyze + val aliasedRelation = tr.where(Symbol("c").attr > 10 && Symbol("a").attr < 5) + .groupBy(Symbol("a"), Symbol("c"), Symbol("b"))( + Symbol("a"), Symbol("c").as("c1"), count(Symbol("a")).as("a3")) + .select(Symbol("c1"), Symbol("a"), Symbol("a3")).analyze // SPARK-16644: aggregate expression count(a) should not appear in the constraints. verifyConstraints(aliasedRelation.analyze.constraints, @@ -98,13 +101,14 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { } test("propagating constraints in expand") { - val tr = LocalRelation('a.int, 'b.int, 'c.int) + val tr = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) assert(tr.analyze.constraints.isEmpty) // We add IsNotNull constraints for 'a, 'b and 'c into LocalRelation // by creating notNullRelation. 
- val notNullRelation = tr.where('c.attr > 10 && 'a.attr < 5 && 'b.attr > 2) + val notNullRelation = + tr.where(Symbol("c").attr > 10 && Symbol("a").attr < 5 && Symbol("b").attr > 2) verifyConstraints(notNullRelation.analyze.constraints, ExpressionSet(Seq(resolveColumn(notNullRelation.analyze, "c") > 10, IsNotNull(resolveColumn(notNullRelation.analyze, "c")), @@ -115,31 +119,36 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { val expand = Expand( Seq( - Seq('c, Literal.create(null, StringType), 1), - Seq('c, 'a, 2)), - Seq('c, 'a, 'gid.int), - Project(Seq('a, 'c), + Seq(Symbol("c"), Literal.create(null, StringType), 1), + Seq(Symbol("c"), Symbol("a"), 2)), + Seq(Symbol("c"), Symbol("a"), Symbol("gid").int), + Project(Seq(Symbol("a"), Symbol("c")), notNullRelation)) verifyConstraints(expand.analyze.constraints, ExpressionSet(Seq.empty[Expression])) } test("propagating constraints in aliases") { - val tr = LocalRelation('a.int, 'b.string, 'c.int) + val tr = LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").int) - assert(tr.where('c.attr > 10).select('a.as('x), 'b.as('y)).analyze.constraints.isEmpty) + assert(tr.where(Symbol("c").attr > 10).select(Symbol("a").as(Symbol("x")), + Symbol("b").as(Symbol("y"))).analyze.constraints.isEmpty) - val aliasedRelation = tr.where('a.attr > 10).select('a.as('x), 'b, 'b.as('y), 'a.as('z)) + val aliasedRelation = tr.where(Symbol("a").attr > 10).select(Symbol("a").as(Symbol("x")), + Symbol("b"), Symbol("b").as(Symbol("y")), Symbol("a").as(Symbol("z"))) verifyConstraints(aliasedRelation.analyze.constraints, ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), - resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), - resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"), + resolveColumn(aliasedRelation.analyze, "b") <=> + resolveColumn(aliasedRelation.analyze, "y"), + resolveColumn(aliasedRelation.analyze, "z") <=> + resolveColumn(aliasedRelation.analyze, "x"), resolveColumn(aliasedRelation.analyze, "z") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))) - val multiAlias = tr.where('a === 'c + 10).select('a.as('x), 'c.as('y)) + val multiAlias = tr.where(Symbol("a") === Symbol("c") + 10).select(Symbol("a").as(Symbol("x")), + Symbol("c").as(Symbol("y"))) verifyConstraints(multiAlias.analyze.constraints, ExpressionSet(Seq(IsNotNull(resolveColumn(multiAlias.analyze, "x")), IsNotNull(resolveColumn(multiAlias.analyze, "y")), @@ -148,46 +157,46 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { } test("propagating constraints in union") { - val tr1 = LocalRelation('a.int, 'b.int, 'c.int) - val tr2 = LocalRelation('d.int, 'e.int, 'f.int) - val tr3 = LocalRelation('g.int, 'h.int, 'i.int) + val tr1 = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val tr2 = LocalRelation(Symbol("d").int, Symbol("e").int, Symbol("f").int) + val tr3 = LocalRelation(Symbol("g").int, Symbol("h").int, Symbol("i").int) assert(tr1 - .where('a.attr > 10) - .union(tr2.where('e.attr > 10) - .union(tr3.where('i.attr > 10))) + .where(Symbol("a").attr > 10) + .union(tr2.where(Symbol("e").attr > 10) + .union(tr3.where(Symbol("i").attr > 10))) .analyze.constraints.isEmpty) verifyConstraints(tr1 - .where('a.attr > 10) - .union(tr2.where('d.attr > 10) - .union(tr3.where('g.attr > 10))) + .where(Symbol("a").attr > 10) + .union(tr2.where(Symbol("d").attr > 10) + 
.union(tr3.where(Symbol("g").attr > 10))) .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, IsNotNull(resolveColumn(tr1, "a"))))) val a = resolveColumn(tr1, "a") verifyConstraints(tr1 - .where('a.attr > 10) - .union(tr2.where('d.attr > 11)) + .where(Symbol("a").attr > 10) + .union(tr2.where(Symbol("d").attr > 11)) .analyze.constraints, ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a)))) val b = resolveColumn(tr1, "b") verifyConstraints(tr1 - .where('a.attr > 10 && 'b.attr < 10) - .union(tr2.where('d.attr > 11 && 'e.attr < 11)) + .where(Symbol("a").attr > 10 && Symbol("b").attr < 10) + .union(tr2.where(Symbol("d").attr > 11 && Symbol("e").attr < 11)) .analyze.constraints, ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b)))) } test("propagating constraints in intersect") { - val tr1 = LocalRelation('a.int, 'b.int, 'c.int) - val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + val tr1 = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val tr2 = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) verifyConstraints(tr1 - .where('a.attr > 10) - .intersect(tr2.where('b.attr < 100), isAll = false) + .where(Symbol("a").attr > 10) + .intersect(tr2.where(Symbol("b").attr < 100), isAll = false) .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100, @@ -196,22 +205,24 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { } test("propagating constraints in except") { - val tr1 = LocalRelation('a.int, 'b.int, 'c.int) - val tr2 = LocalRelation('a.int, 'b.int, 'c.int) + val tr1 = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val tr2 = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) verifyConstraints(tr1 - .where('a.attr > 10) - .except(tr2.where('b.attr < 100), isAll = false) + .where(Symbol("a").attr > 10) + .except(tr2.where(Symbol("b").attr < 100), isAll = false) .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, IsNotNull(resolveColumn(tr1, "a"))))) } test("propagating constraints in inner join") { - val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) - val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + val tr1 = + LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int).subquery(Symbol("tr1")) + val tr2 = + LocalRelation(Symbol("a").int, Symbol("d").int, Symbol("e").int).subquery(Symbol("tr2")) verifyConstraints(tr1 - .where('a.attr > 10) - .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) + .where(Symbol("a").attr > 10) + .join(tr2.where(Symbol("d").attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints, ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, @@ -224,51 +235,59 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { } test("propagating constraints in left-semi join") { - val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) - val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + val tr1 = + LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int).subquery(Symbol("tr1")) + val tr2 = + LocalRelation(Symbol("a").int, Symbol("d").int, Symbol("e").int).subquery(Symbol("tr2")) verifyConstraints(tr1 - .where('a.attr > 10) - .join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr)) + .where(Symbol("a").attr > 10) + .join(tr2.where(Symbol("d").attr < 100), LeftSemi, 
Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints, ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))) } test("propagating constraints in left-outer join") { - val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) - val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + val tr1 = + LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int).subquery(Symbol("tr1")) + val tr2 = + LocalRelation(Symbol("a").int, Symbol("d").int, Symbol("e").int).subquery(Symbol("tr2")) verifyConstraints(tr1 - .where('a.attr > 10) - .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr)) + .where(Symbol("a").attr > 10) + .join(tr2.where(Symbol("d").attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints, ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))) } test("propagating constraints in right-outer join") { - val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) - val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) + val tr1 = + LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int).subquery(Symbol("tr1")) + val tr2 = + LocalRelation(Symbol("a").int, Symbol("d").int, Symbol("e").int).subquery(Symbol("tr2")) verifyConstraints(tr1 - .where('a.attr > 10) - .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr)) + .where(Symbol("a").attr > 10) + .join(tr2.where(Symbol("d").attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints, ExpressionSet(Seq(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))) } test("propagating constraints in full-outer join") { - val tr1 = LocalRelation('a.int, 'b.int, 'c.int).subquery('tr1) - val tr2 = LocalRelation('a.int, 'd.int, 'e.int).subquery('tr2) - assert(tr1.where('a.attr > 10) - .join(tr2.where('d.attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr)) + val tr1 = + LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int).subquery(Symbol("tr1")) + val tr2 = + LocalRelation(Symbol("a").int, Symbol("d").int, Symbol("e").int).subquery(Symbol("tr2")) + assert(tr1.where(Symbol("a").attr > 10) + .join(tr2.where(Symbol("d").attr < 100), FullOuter, Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints.isEmpty) } test("infer additional constraints in filters") { - val tr = LocalRelation('a.int, 'b.int, 'c.int) + val tr = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) verifyConstraints(tr - .where('a.attr > 10 && 'a.attr === 'b.attr) + .where(Symbol("a").attr > 10 && Symbol("a").attr === Symbol("b").attr) .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr, "a") > 10, resolveColumn(tr, "b") > 10, @@ -278,10 +297,11 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { } test("infer constraints on cast") { - val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int) + val tr = LocalRelation( + Symbol("a").int, Symbol("b").long, Symbol("c").int, Symbol("d").long, Symbol("e").int) verifyConstraints( - tr.where('a.attr === 'b.attr && - 'c.attr + 100 > 'd.attr && + tr.where(Symbol("a").attr === Symbol("b").attr && + Symbol("c").attr + 100 > Symbol("d").attr && IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints, ExpressionSet(Seq( castWithTimeZone(resolveColumn(tr, "a"), 
LongType) === resolveColumn(tr, "b"), @@ -291,16 +311,19 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { IsNotNull(resolveColumn(tr, "c")), IsNotNull(resolveColumn(tr, "d")), IsNotNull(resolveColumn(tr, "e")), - IsNotNull(castWithTimeZone(castWithTimeZone(resolveColumn(tr, "e"), LongType), LongType))))) + IsNotNull(castWithTimeZone(castWithTimeZone( + resolveColumn(tr, "e"), LongType), LongType))))) } test("infer isnotnull constraints from compound expressions") { - val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int) + val tr = LocalRelation( + Symbol("a").int, Symbol("b").long, Symbol("c").int, Symbol("d").long, Symbol("e").int) verifyConstraints( - tr.where('a.attr + 'b.attr === 'c.attr && + tr.where(Symbol("a").attr + Symbol("b").attr === Symbol("c").attr && IsNotNull( Cast( - Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints, + Cast( + Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints, ExpressionSet(Seq( castWithTimeZone(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") === castWithTimeZone(resolveColumn(tr, "c"), LongType), @@ -313,7 +336,8 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { resolveColumn(tr, "e"), LongType), LongType), LongType))))) verifyConstraints( - tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints, + tr.where((Symbol("a").attr * Symbol("b").attr + 100) === Symbol("c").attr && + Symbol("d") / 10 === Symbol("e")).analyze.constraints, ExpressionSet(Seq( castWithTimeZone(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + castWithTimeZone(100, LongType) === @@ -328,7 +352,8 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { IsNotNull(resolveColumn(tr, "e"))))) verifyConstraints( - tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints, + tr.where((Symbol("a").attr * Symbol("b").attr - 10) >= Symbol("c").attr && + Symbol("d") / 10 < Symbol("e")).analyze.constraints, ExpressionSet(Seq( castWithTimeZone(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - castWithTimeZone(10, LongType) >= @@ -343,7 +368,8 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { IsNotNull(resolveColumn(tr, "e"))))) verifyConstraints( - tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints, + tr.where(Symbol("a").attr + Symbol("b").attr - Symbol("c").attr * Symbol("d").attr > + Symbol("e").attr * 1000).analyze.constraints, ExpressionSet(Seq( (castWithTimeZone(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) - (castWithTimeZone(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) > @@ -356,7 +382,7 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { // The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null. 
verifyConstraints( - tr.where('a.attr === 'c.attr && + tr.where(Symbol("a").attr === Symbol("c").attr && IsNotNull(IsNotNull(resolveColumn(tr, "b")))).analyze.constraints, ExpressionSet(Seq( resolveColumn(tr, "a") === resolveColumn(tr, "c"), @@ -365,7 +391,7 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { IsNotNull(resolveColumn(tr, "c"))))) verifyConstraints( - tr.where('a.attr === 1 && IsNotNull(resolveColumn(tr, "b")) && + tr.where(Symbol("a").attr === 1 && IsNotNull(resolveColumn(tr, "b")) && IsNotNull(resolveColumn(tr, "c"))).analyze.constraints, ExpressionSet(Seq( resolveColumn(tr, "a") === 1, @@ -375,7 +401,8 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { } test("infer IsNotNull constraints from non-nullable attributes") { - val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(), + val tr = LocalRelation(Symbol("a").int, + AttributeReference("b", IntegerType, nullable = false)(), AttributeReference("c", StringType, nullable = false)()) verifyConstraints(tr.analyze.constraints, @@ -383,16 +410,16 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { } test("not infer non-deterministic constraints") { - val tr = LocalRelation('a.int, 'b.string, 'c.int) + val tr = LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").int) verifyConstraints(tr - .where('a.attr === Rand(0)) + .where(Symbol("a").attr === Rand(0)) .analyze.constraints, ExpressionSet(Seq(IsNotNull(resolveColumn(tr, "a"))))) verifyConstraints(tr - .where('a.attr === InputFileName()) - .where('a.attr =!= 'c.attr) + .where(Symbol("a").attr === InputFileName()) + .where(Symbol("a").attr =!= Symbol("c").attr) .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr, "a") =!= resolveColumn(tr, "c"), IsNotNull(resolveColumn(tr, "a")), @@ -400,8 +427,8 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { } test("enable/disable constraint propagation") { - val tr = LocalRelation('a.int, 'b.string, 'c.int) - val filterRelation = tr.where('a.attr > 10) + val tr = LocalRelation(Symbol("a").int, Symbol("b").string, Symbol("c").int) + val filterRelation = tr.where(Symbol("a").attr > 10) withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") { assert(filterRelation.analyze.constraints.nonEmpty) @@ -411,8 +438,10 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest { assert(filterRelation.analyze.constraints.isEmpty) } - val aliasedRelation = tr.where('c.attr > 10 && 'a.attr < 5) - .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a, 'a3) + val aliasedRelation = tr.where(Symbol("c").attr > 10 && Symbol("a").attr < 5) + .groupBy(Symbol("a"), Symbol("c"), Symbol("b"))( + Symbol("a"), Symbol("c").as("c1"), count(Symbol("a")).as("a3")) + .select(Symbol("c1"), Symbol("a"), Symbol("a3")) withSQLConf(SQLConf.CONSTRAINT_PROPAGATION_ENABLED.key -> "true") { assert(aliasedRelation.analyze.constraints.nonEmpty) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala index 404c8895c4d1..fa5dfb48afa9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/QueryPlanSuite.scala @@ -97,7 +97,7 @@ class QueryPlanSuite extends SparkFunSuite { } } - val t = LocalRelation('a.int, 'b.int) + val t = LocalRelation(Symbol("a").int, 
Symbol("b").int) val plan = t.select($"a", $"b").select($"a", $"b").select($"a", $"b").analyze assert(testRule(plan).resolved) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index fbaaf807af5d..596f7a88fbcc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -29,8 +29,8 @@ import org.apache.spark.sql.catalyst.util._ * Tests for the sameResult function of [[LogicalPlan]]. */ class SameResultSuite extends SparkFunSuite { - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) + val testRelation2 = LocalRelation(Symbol("a").int, Symbol("b").int, Symbol("c").int) object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("EliminateResolvedHint", Once, EliminateResolvedHint) :: Nil @@ -51,21 +51,25 @@ class SameResultSuite extends SparkFunSuite { } test("projections") { - assertSameResult(testRelation.select('a), testRelation2.select('a)) - assertSameResult(testRelation.select('b), testRelation2.select('b)) - assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b)) - assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a)) + assertSameResult(testRelation.select(Symbol("a")), testRelation2.select(Symbol("a"))) + assertSameResult(testRelation.select(Symbol("b")), testRelation2.select(Symbol("b"))) + assertSameResult(testRelation.select(Symbol("a"), Symbol("b")), + testRelation2.select(Symbol("a"), Symbol("b"))) + assertSameResult(testRelation.select(Symbol("b"), Symbol("a")), + testRelation2.select(Symbol("b"), Symbol("a"))) - assertSameResult(testRelation, testRelation2.select('a), result = false) - assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), result = false) + assertSameResult(testRelation, testRelation2.select(Symbol("a")), result = false) + assertSameResult(testRelation.select(Symbol("b"), Symbol("a")), + testRelation2.select(Symbol("a"), Symbol("b")), result = false) } test("filters") { - assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b)) + assertSameResult(testRelation.where(Symbol("a") === Symbol("b")), + testRelation2.where(Symbol("a") === Symbol("b"))) } test("sorts") { - assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc)) + assertSameResult(testRelation.orderBy(Symbol("a").asc), testRelation2.orderBy(Symbol("a").asc)) } test("union") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala index 6f342b8d9437..e0a284a1908b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala @@ -31,7 +31,7 @@ class LogicalPlanIntegritySuite extends PlanTest { } test("Checks if the same `ExprId` refers to a semantically-equal attribute in a plan output") { - val t = LocalRelation('a.int, 'b.int) + val t = LocalRelation(Symbol("a").int, Symbol("b").int) assert(hasUniqueExprIdsForOutput(OutputTestPlan(t, 
t.output))) assert(!hasUniqueExprIdsForOutput(OutputTestPlan(t, t.output.zipWithIndex.map { case (a, i) => AttributeReference(s"c$i", LongType)(a.exprId) @@ -39,7 +39,7 @@ class LogicalPlanIntegritySuite extends PlanTest { } test("Checks if reference ExprIds are not reused when assigning a new ExprId") { - val t = LocalRelation('a.int, 'b.int) + val t = LocalRelation(Symbol("a").int, Symbol("b").int) val Seq(a, b) = t.output assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")()))) assert(!checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = a.exprId)))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 749fed394073..d97c1b101e61 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -52,7 +52,7 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { } test("windows") { - val windows = plan.window(Seq(min(attribute).as('sum_attr)), Seq(attribute), Nil) + val windows = plan.window(Seq(min(attribute).as(Symbol("sum_attr"))), Seq(attribute), Nil) val windowsStats = Statistics(sizeInBytes = plan.size.get * (4 + 4 + 8) / (4 + 8)) checkStats( windows, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 4f64de4ae875..5f0dbadf18eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -1012,17 +1012,17 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should throw an exception if any intermediate structs don't exist") { intercept[AnalysisException] { - structLevel2.withColumn("a", 'a.withField("x.b", lit(2))) + structLevel2.withColumn("a", Symbol("a").withField("x.b", lit(2))) }.getMessage should include("No such struct field x in a") intercept[AnalysisException] { - structLevel3.withColumn("a", 'a.withField("a.x.b", lit(2))) + structLevel3.withColumn("a", Symbol("a").withField("a.x.b", lit(2))) }.getMessage should include("No such struct field x in a") } test("withField should throw an exception if intermediate field is not a struct") { intercept[AnalysisException] { - structLevel1.withColumn("a", 'a.withField("b.a", lit(2))) + structLevel1.withColumn("a", Symbol("a").withField("b.a", lit(2))) }.getMessage should include("struct argument should be struct type, got: int") } @@ -1036,7 +1036,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("a", structType, nullable = false))), nullable = false)))) - structLevel2.withColumn("a", 'a.withField("a.b", lit(2))) + structLevel2.withColumn("a", Symbol("a").withField("a.b", lit(2))) }.getMessage should include("Ambiguous reference to fields") } @@ -1055,7 +1055,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add field to struct") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("d", lit(4))), + structLevel1.withColumn("a", Symbol("a").withField("d", lit(4))), Row(Row(1, null, 3, 4)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1096,7 +1096,7 @@ class ColumnExpressionSuite 
extends QueryTest with SharedSparkSession { test("withField should add null field to struct") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("d", lit(null).cast(IntegerType))), + structLevel1.withColumn("a", Symbol("a").withField("d", lit(null).cast(IntegerType))), Row(Row(1, null, 3, null)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1109,7 +1109,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add multiple fields to struct") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))), + structLevel1.withColumn("a", Symbol("a").withField("d", lit(4)).withField("e", lit(5))), Row(Row(1, null, 3, 4, 5)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1123,7 +1123,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add multiple fields to nullable struct") { checkAnswer( - nullableStructLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))), + nullableStructLevel1.withColumn( + "a", Symbol("a").withField("d", lit(4)).withField("e", lit(5))), Row(null) :: Row(Row(1, null, 3, 4, 5)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1137,8 +1138,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add field to nested struct") { Seq( - structLevel2.withColumn("a", 'a.withField("a.d", lit(4))), - structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("d", lit(4)))) + structLevel2.withColumn("a", Symbol("a").withField("a.d", lit(4))), + structLevel2.withColumn("a", Symbol("a").withField("a", $"a.a".withField("d", lit(4)))) ).foreach { df => checkAnswer( df, @@ -1199,7 +1200,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add field to deeply nested struct") { checkAnswer( - structLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))), + structLevel3.withColumn("a", Symbol("a").withField("a.a.d", lit(4))), Row(Row(Row(Row(1, null, 3, 4)))) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1216,7 +1217,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace field in struct") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("b", lit(2))), + structLevel1.withColumn("a", Symbol("a").withField("b", lit(2))), Row(Row(1, 2, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1228,7 +1229,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace field in nullable struct") { checkAnswer( - nullableStructLevel1.withColumn("a", 'a.withField("b", lit("foo"))), + nullableStructLevel1.withColumn("a", Symbol("a").withField("b", lit("foo"))), Row(null) :: Row(Row(1, "foo", 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1254,7 +1255,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace field with null value in struct") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("c", lit(null).cast(IntegerType))), + structLevel1.withColumn("a", Symbol("a").withField("c", lit(null).cast(IntegerType))), Row(Row(1, null, null)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1266,7 +1267,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace multiple fields in struct") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("a", 
lit(10)).withField("b", lit(20))), + structLevel1.withColumn("a", Symbol("a").withField("a", lit(10)).withField("b", lit(20))), Row(Row(10, 20, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1278,7 +1279,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace multiple fields in nullable struct") { checkAnswer( - nullableStructLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))), + nullableStructLevel1.withColumn( + "a", Symbol("a").withField("a", lit(10)).withField("b", lit(20))), Row(null) :: Row(Row(10, 20, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1291,7 +1293,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace field in nested struct") { Seq( structLevel2.withColumn("a", $"a".withField("a.b", lit(2))), - structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("b", lit(2)))) + structLevel2.withColumn("a", Symbol("a").withField("a", $"a.a".withField("b", lit(2)))) ).foreach { df => checkAnswer( df, @@ -1372,7 +1374,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - structLevel1.withColumn("a", 'a.withField("b", lit(100))), + structLevel1.withColumn("a", Symbol("a").withField("b", lit(100))), Row(Row(1, 100, 100)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1384,7 +1386,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace fields in struct in given order") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("b", lit(2)).withField("b", lit(20))), + structLevel1.withColumn("a", Symbol("a").withField("b", lit(2)).withField("b", lit(20))), Row(Row(1, 20, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1396,7 +1398,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add field and then replace same field in struct") { checkAnswer( - structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("d", lit(5))), + structLevel1.withColumn("a", Symbol("a").withField("d", lit(4)).withField("d", lit(5))), Row(Row(1, null, 3, 5)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1420,7 +1422,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - df.withColumn("a", 'a.withField("`a.b`.`e.f`", lit(2))), + df.withColumn("a", Symbol("a").withField("`a.b`.`e.f`", lit(2))), Row(Row(Row(1, 2, 3))) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1432,7 +1434,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) intercept[AnalysisException] { - df.withColumn("a", 'a.withField("a.b.e.f", lit(2))) + df.withColumn("a", Symbol("a").withField("a.b.e.f", lit(2))) }.getMessage should include("No such struct field a in a.b") } @@ -1447,7 +1449,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace field in struct even if casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))), + mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("A", lit(2))), Row(Row(2, 1)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1456,7 +1458,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - 
mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))), + mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("b", lit(2))), Row(Row(1, 2)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1469,7 +1471,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add field to struct because casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))), + mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("A", lit(2))), Row(Row(1, 1, 2)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1479,7 +1481,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))), + mixedCaseStructLevel1.withColumn("a", Symbol("a").withField("b", lit(2))), Row(Row(1, 1, 2)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1507,7 +1509,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace nested field in struct even if casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { checkAnswer( - mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))), + mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("A.a", lit(2))), Row(Row(Row(2, 1), Row(1, 1))) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1522,7 +1524,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))), + mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("b.a", lit(2))), Row(Row(Row(1, 1), Row(2, 1))) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1541,11 +1543,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should throw an exception because casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { intercept[AnalysisException] { - mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))) + mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("A.a", lit(2))) }.getMessage should include("No such struct field A in a, B") intercept[AnalysisException] { - mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))) + mixedCaseStructLevel2.withColumn("a", Symbol("a").withField("b.a", lit(2))) }.getMessage should include("No such struct field b in a, B") } } @@ -1697,17 +1699,17 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should throw an exception if any intermediate structs don't exist") { intercept[AnalysisException] { - structLevel2.withColumn("a", 'a.dropFields("x.b")) + structLevel2.withColumn("a", Symbol("a").dropFields("x.b")) }.getMessage should include("No such struct field x in a") intercept[AnalysisException] { - structLevel3.withColumn("a", 'a.dropFields("a.x.b")) + structLevel3.withColumn("a", Symbol("a").dropFields("a.x.b")) }.getMessage should include("No such struct field x in a") } test("dropFields should throw an exception if intermediate field is not a struct") { intercept[AnalysisException] { - structLevel1.withColumn("a", 'a.dropFields("b.a")) + structLevel1.withColumn("a", Symbol("a").dropFields("b.a")) }.getMessage should include("struct argument should be struct type, got: int") } @@ -1721,13 +1723,13 @@ class ColumnExpressionSuite extends QueryTest 
with SharedSparkSession { StructField("a", structType, nullable = false))), nullable = false)))) - structLevel2.withColumn("a", 'a.dropFields("a.b")) + structLevel2.withColumn("a", Symbol("a").dropFields("a.b")) }.getMessage should include("Ambiguous reference to fields") } test("dropFields should drop field in struct") { checkAnswer( - structLevel1.withColumn("a", 'a.dropFields("b")), + structLevel1.withColumn("a", Symbol("a").dropFields("b")), Row(Row(1, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1750,7 +1752,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should drop multiple fields in struct") { Seq( structLevel1.withColumn("a", $"a".dropFields("b", "c")), - structLevel1.withColumn("a", 'a.dropFields("b").dropFields("c")) + structLevel1.withColumn("a", Symbol("a").dropFields("b").dropFields("c")) ).foreach { df => checkAnswer( df, @@ -1764,7 +1766,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should throw an exception if no fields will be left in struct") { intercept[AnalysisException] { - structLevel1.withColumn("a", 'a.dropFields("a", "b", "c")) + structLevel1.withColumn("a", Symbol("a").dropFields("a", "b", "c")) }.getMessage should include("cannot drop all fields in struct") } @@ -1788,7 +1790,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should drop field in nested struct") { checkAnswer( - structLevel2.withColumn("a", 'a.dropFields("a.b")), + structLevel2.withColumn("a", Symbol("a").dropFields("a.b")), Row(Row(Row(1, 3))) :: Nil, StructType( Seq(StructField("a", StructType(Seq( @@ -1801,7 +1803,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should drop multiple fields in nested struct") { checkAnswer( - structLevel2.withColumn("a", 'a.dropFields("a.b", "a.c")), + structLevel2.withColumn("a", Symbol("a").dropFields("a.b", "a.c")), Row(Row(Row(1))) :: Nil, StructType( Seq(StructField("a", StructType(Seq( @@ -1838,7 +1840,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should drop field in deeply nested struct") { checkAnswer( - structLevel3.withColumn("a", 'a.dropFields("a.a.b")), + structLevel3.withColumn("a", Symbol("a").dropFields("a.a.b")), Row(Row(Row(Row(1, 3)))) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1862,7 +1864,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - structLevel1.withColumn("a", 'a.dropFields("b")), + structLevel1.withColumn("a", Symbol("a").dropFields("b")), Row(Row(1)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1873,7 +1875,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should drop field in struct even if casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")), + mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("A")), Row(Row(1)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1881,7 +1883,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")), + mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("b")), Row(Row(1)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1893,7 +1895,7 @@ class ColumnExpressionSuite 
extends QueryTest with SharedSparkSession { test("dropFields should not drop field in struct because casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")), + mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("A")), Row(Row(1, 1)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1902,7 +1904,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")), + mixedCaseStructLevel1.withColumn("a", Symbol("a").dropFields("b")), Row(Row(1, 1)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1915,7 +1917,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should drop nested field in struct even if casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { checkAnswer( - mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")), + mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("A.a")), Row(Row(Row(1), Row(1, 1))) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1929,7 +1931,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")), + mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("b.a")), Row(Row(Row(1, 1), Row(1))) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1947,18 +1949,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should throw an exception because casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { intercept[AnalysisException] { - mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")) + mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("A.a")) }.getMessage should include("No such struct field A in a, B") intercept[AnalysisException] { - mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")) + mixedCaseStructLevel2.withColumn("a", Symbol("a").dropFields("b.a")) }.getMessage should include("No such struct field b in a, B") } } test("dropFields should drop only fields that exist") { checkAnswer( - structLevel1.withColumn("a", 'a.dropFields("d")), + structLevel1.withColumn("a", Symbol("a").dropFields("d")), Row(Row(1, null, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( @@ -1968,7 +1970,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - structLevel1.withColumn("a", 'a.dropFields("b", "d")), + structLevel1.withColumn("a", Symbol("a").dropFields("b", "d")), Row(Row(1, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 28e2da3970a4..0735194adb17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -117,7 +117,7 @@ class DataFrameSuite extends QueryTest val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words") checkAnswer( - df.explode("words", "word") { word: String => word.split(" ").toSeq }.select('word), + df.explode("words", "word") { word: String => word.split(" ").toSeq }.select(Symbol("word")), Row("a") :: Row("b") :: Row("c") :: Row("d") ::Row("e") :: Nil ) } @@ -125,15 +125,15 @@ class DataFrameSuite 
extends QueryTest test("explode") { val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") val df2 = - df.explode('letters) { + df.explode(Symbol("letters")) { case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq } checkAnswer( df2 - .select('_1 as 'letter, 'number) - .groupBy('letter) - .agg(count_distinct('number)), + .select(Symbol("_1") as Symbol("letter"), Symbol("number")) + .groupBy(Symbol("letter")) + .agg(count_distinct(Symbol("number"))), Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil ) } @@ -314,7 +314,7 @@ class DataFrameSuite extends QueryTest assert(e.getMessage.contains("Invalid usage of '*' in explode/json_tuple/UDTF")) checkAnswer( - df.explode('prefix, 'csv) { case Row(prefix: String, csv: String) => + df.explode(Symbol("prefix"), Symbol("csv")) { case Row(prefix: String, csv: String) => csv.split(",").map(v => Tuple1(prefix + ":" + v)).toSeq }, Row("1", "1,2", "1:1") :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedAPISuite.scala index 25b8849d6124..cbcb591cf571 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedAPISuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedAPISuite.scala @@ -33,12 +33,12 @@ class DeprecatedAPISuite extends QueryTest with SharedSparkSession { c: Column => Column, f: T => U): Unit = { checkAnswer( - doubleData.select(c('a)), + doubleData.select(c(Symbol("a"))), (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) ) checkAnswer( - doubleData.select(c('b)), + doubleData.select(c(Symbol("b"))), (1 to 10).map(n => Row(f((-n * 0.2 + 1).asInstanceOf[T]))) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 876f62803dc7..62f5bed1e082 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -885,51 +885,82 @@ class FileBasedDataSourceSuite extends QueryTest // cases when value == MAX var v = Short.MaxValue - checkPushedFilters(format, df.where('id > v.toInt), Array(), noScan = true) - checkPushedFilters(format, df.where('id >= v.toInt), Array(sources.IsNotNull("id"), + checkPushedFilters(format, df.where(Symbol("id") > v.toInt), Array(), noScan = true) + checkPushedFilters( + format, + df.where(Symbol("id") >= v.toInt), + Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) + checkPushedFilters( + format, + df.where(Symbol("id") === v.toInt), + Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) - checkPushedFilters(format, df.where('id === v.toInt), Array(sources.IsNotNull("id"), - sources.EqualTo("id", v))) - checkPushedFilters(format, df.where('id <=> v.toInt), + checkPushedFilters(format, df.where(Symbol("id") <=> v.toInt), Array(sources.EqualNullSafe("id", v))) - checkPushedFilters(format, df.where('id <= v.toInt), Array(sources.IsNotNull("id"))) - checkPushedFilters(format, df.where('id < v.toInt), Array(sources.IsNotNull("id"), + checkPushedFilters( + format, + df.where(Symbol("id") <= v.toInt), + Array(sources.IsNotNull("id"))) + checkPushedFilters( + format, + df.where(Symbol("id") < v.toInt), + Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v)))) // cases when value > MAX var v1: Int = positiveInt - checkPushedFilters(format, df.where('id > v1), Array(), noScan = true) - checkPushedFilters(format, 
df.where('id >= v1), Array(), noScan = true) - checkPushedFilters(format, df.where('id === v1), Array(), noScan = true) - checkPushedFilters(format, df.where('id <=> v1), Array(), noScan = true) - checkPushedFilters(format, df.where('id <= v1), Array(sources.IsNotNull("id"))) - checkPushedFilters(format, df.where('id < v1), Array(sources.IsNotNull("id"))) + checkPushedFilters(format, df.where(Symbol("id") > v1), Array(), noScan = true) + checkPushedFilters(format, df.where(Symbol("id") >= v1), Array(), noScan = true) + checkPushedFilters(format, df.where(Symbol("id") === v1), Array(), noScan = true) + checkPushedFilters(format, df.where(Symbol("id") <=> v1), Array(), noScan = true) + checkPushedFilters(format, df.where(Symbol("id") <= v1), Array(sources.IsNotNull("id"))) + checkPushedFilters(format, df.where(Symbol("id") < v1), Array(sources.IsNotNull("id"))) // cases when value = MIN v = Short.MinValue - checkPushedFilters(format, df.where(lit(v.toInt) < 'id), Array(sources.IsNotNull("id"), + checkPushedFilters( + format, df.where(lit(v.toInt) < Symbol("id")), + Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v)))) - checkPushedFilters(format, df.where(lit(v.toInt) <= 'id), Array(sources.IsNotNull("id"))) - checkPushedFilters(format, df.where(lit(v.toInt) === 'id), Array(sources.IsNotNull("id"), + checkPushedFilters( + format, + df.where(lit(v.toInt) <= Symbol("id")), + Array(sources.IsNotNull("id"))) + checkPushedFilters( + format, + df.where(lit(v.toInt) === Symbol("id")), + Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) - checkPushedFilters(format, df.where(lit(v.toInt) <=> 'id), + checkPushedFilters(format, df.where(lit(v.toInt) <=> Symbol("id")), Array(sources.EqualNullSafe("id", v))) - checkPushedFilters(format, df.where(lit(v.toInt) >= 'id), Array(sources.IsNotNull("id"), + checkPushedFilters( + format, + df.where(lit(v.toInt) >= Symbol("id")), + Array(sources.IsNotNull("id"), sources.EqualTo("id", v))) - checkPushedFilters(format, df.where(lit(v.toInt) > 'id), Array(), noScan = true) + checkPushedFilters(format, df.where(lit(v.toInt) > Symbol("id")), Array(), noScan = true) // cases when value < MIN v1 = negativeInt - checkPushedFilters(format, df.where(lit(v1) < 'id), Array(sources.IsNotNull("id"))) - checkPushedFilters(format, df.where(lit(v1) <= 'id), Array(sources.IsNotNull("id"))) - checkPushedFilters(format, df.where(lit(v1) === 'id), Array(), noScan = true) - checkPushedFilters(format, df.where(lit(v1) >= 'id), Array(), noScan = true) - checkPushedFilters(format, df.where(lit(v1) > 'id), Array(), noScan = true) + checkPushedFilters( + format, + df.where(lit(v1) < Symbol("id")), + Array(sources.IsNotNull("id"))) + checkPushedFilters( + format, + df.where(lit(v1) <= Symbol("id")), + Array(sources.IsNotNull("id"))) + checkPushedFilters(format, df.where(lit(v1) === Symbol("id")), Array(), noScan = true) + checkPushedFilters(format, df.where(lit(v1) >= Symbol("id")), Array(), noScan = true) + checkPushedFilters(format, df.where(lit(v1) > Symbol("id")), Array(), noScan = true) // cases when value is within range (MIN, MAX) - checkPushedFilters(format, df.where('id > 30), Array(sources.IsNotNull("id"), + checkPushedFilters(format, df.where(Symbol("id") > 30), Array(sources.IsNotNull("id"), sources.GreaterThan("id", 30))) - checkPushedFilters(format, df.where(lit(100) >= 'id), Array(sources.IsNotNull("id"), + checkPushedFilters( + format, + df.where(lit(100) >= Symbol("id")), + Array(sources.IsNotNull("id"), sources.LessThanOrEqual("id", 100))) } } diff 
--git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 2e336b264cd3..16ec05a61d06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -183,7 +183,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan test("inner join where, one match per row") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { checkAnswer( - upperCaseData.join(lowerCaseData).where('n === 'N), + upperCaseData.join(lowerCaseData).where(Symbol("n") === Symbol("N")), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -404,8 +404,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan test("full outer join") { withTempView("`left`", "`right`") { - upperCaseData.where('N <= 4).createOrReplaceTempView("`left`") - upperCaseData.where('N >= 3).createOrReplaceTempView("`right`") + upperCaseData.where(Symbol("N") <= 4).createOrReplaceTempView("`left`") + upperCaseData.where(Symbol("N") >= 3).createOrReplaceTempView("`right`") val left = UnresolvedRelation(TableIdentifier("left")) val right = UnresolvedRelation(TableIdentifier("right")) @@ -623,7 +623,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan testData.createOrReplaceTempView("B") testData2.createOrReplaceTempView("C") testData3.createOrReplaceTempView("D") - upperCaseData.where('N >= 3).createOrReplaceTempView("`right`") + upperCaseData.where(Symbol("N") >= 3).createOrReplaceTempView("`right`") val cartesianQueries = Seq( /** The following should error out since there is no explicit cross join */ "SELECT * FROM testData inner join testData2", @@ -1097,7 +1097,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan } test("SPARK-29850: sort-merge-join an empty table should not memory leak") { - val df1 = spark.range(10).select($"id", $"id" % 3 as 'p) + val df1 = spark.range(10).select($"id", $"id" % 3 as Symbol("p")) .repartition($"id").groupBy($"id").agg(Map("p" -> "max")) val df2 = spark.range(0) withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 310e170e8c1b..e5724fcb0c6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -397,7 +397,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { test("SPARK-24709: infers schemas of json strings and pass them to from_json") { val in = Seq("""{"a": [1, 2, 3]}""").toDS() - val out = in.select(from_json('value, schema_of_json("""{"a": [1]}""")) as "parsed") + val out = in.select(from_json(Symbol("value"), schema_of_json("""{"a": [1]}""")) as "parsed") val expected = StructType(StructField( "parsed", StructType(StructField( @@ -659,8 +659,8 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { Seq(2000, 2800, 8000 - 1, 8000, 8000 + 1, 65535).foreach { len => val str = Array.tabulate(len)(_ => "a").mkString val json_tuple_result = Seq(s"""{"test":"$str"}""").toDF("json") - .withColumn("result", json_tuple('json, "test")) - .select('result) + .withColumn("result", json_tuple(Symbol("json"), "test")) + .select(Symbol("result")) .as[String].head.length assert(json_tuple_result === len) } @@ -750,7 +750,7 @@ class 
JsonFunctionsSuite extends QueryTest with SharedSparkSession { test("SPARK-33270: infers schema for JSON field with spaces and pass them to from_json") { val in = Seq("""{"a b": 1}""").toDS() - val out = in.select(from_json('value, schema_of_json("""{"a b": 100}""")) as "parsed") + val out = in.select(from_json(Symbol("value"), schema_of_json("""{"a b": 100}""")) as "parsed") val expected = new StructType().add("parsed", new StructType().add("a b", LongType)) assert(out.schema == expected) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala index 3509804bdeb6..e2adfd9ffe62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala @@ -46,12 +46,12 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { c: Column => Column, f: T => U): Unit = { checkAnswer( - doubleData.select(c('a)), + doubleData.select(c(Symbol("a"))), (1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T]))) ) checkAnswer( - doubleData.select(c('b)), + doubleData.select(c(Symbol("b"))), (1 to 10).map(n => Row(f((-n * 0.2 + 1).asInstanceOf[T]))) ) @@ -64,13 +64,13 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { private def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { checkAnswer( - nnDoubleData.select(c('a)), + nnDoubleData.select(c(Symbol("a"))), (1 to 10).map(n => Row(f(n * 0.1))) ) if (f(-1) === StrictMath.log1p(-1)) { checkAnswer( - nnDoubleData.select(c('b)), + nnDoubleData.select(c(Symbol("b"))), (1 to 9).map(n => Row(f(n * -0.1))) :+ Row(null) ) } @@ -86,29 +86,29 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { d: (Column, Double) => Column, f: (Double, Double) => Double): Unit = { checkAnswer( - nnDoubleData.select(c('a, 'a)), + nnDoubleData.select(c(Symbol("a"), Symbol("a"))), nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) ) checkAnswer( - nnDoubleData.select(c('a, 'b)), + nnDoubleData.select(c(Symbol("a"), Symbol("b"))), nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1)))) ) checkAnswer( - nnDoubleData.select(d('a, 2.0)), + nnDoubleData.select(d(Symbol("a"), 2.0)), nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), 2.0))) ) checkAnswer( - nnDoubleData.select(d('a, -0.5)), + nnDoubleData.select(d(Symbol("a"), -0.5)), nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), -0.5))) ) val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null) checkAnswer( - nullDoubles.select(c('a, 'a)).orderBy('a.asc), + nullDoubles.select(c(Symbol("a"), Symbol("a"))).orderBy(Symbol("a").asc), Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0)))) ) } @@ -193,7 +193,7 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { test("conv") { val df = Seq(("333", 10, 2)).toDF("num", "fromBase", "toBase") - checkAnswer(df.select(conv('num, 10, 16)), Row("14D")) + checkAnswer(df.select(conv(Symbol("num"), 10, 16)), Row("14D")) checkAnswer(df.select(conv(lit(100), 2, 16)), Row("4")) checkAnswer(df.select(conv(lit(3122234455L), 10, 16)), Row("BA198457")) checkAnswer(df.selectExpr("conv(num, fromBase, toBase)"), Row("101001101")) @@ -210,7 +210,7 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { test("factorial") { val df = (0 to 5).map(i => (i, i)).toDF("a", "b") checkAnswer( - 
df.select(factorial('a)), + df.select(factorial(Symbol("a"))), Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120)) ) checkAnswer( @@ -226,11 +226,11 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { test("round/bround") { val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") checkAnswer( - df.select(round('a), round('a, -1), round('a, -2)), + df.select(round(Symbol("a")), round(Symbol("a"), -1), round(Symbol("a"), -2)), Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) ) checkAnswer( - df.select(bround('a), bround('a, -1), bround('a, -2)), + df.select(bround(Symbol("a")), bround(Symbol("a"), -1), bround(Symbol("a"), -2)), Seq(Row(5, 0, 0), Row(55, 60, 100), Row(555, 560, 600)) ) @@ -267,11 +267,11 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { test("round/bround with data frame from a local Seq of Product") { val df = spark.createDataFrame(Seq(Tuple1(BigDecimal("5.9")))).toDF("value") checkAnswer( - df.withColumn("value_rounded", round('value)), + df.withColumn("value_rounded", round(Symbol("value"))), Seq(Row(BigDecimal("5.9"), BigDecimal("6"))) ) checkAnswer( - df.withColumn("value_brounded", bround('value)), + df.withColumn("value_brounded", bround(Symbol("value"))), Seq(Row(BigDecimal("5.9"), BigDecimal("6"))) ) } @@ -315,10 +315,10 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { test("hex") { val data = Seq((28, -28, 100800200404L, "hello")).toDF("a", "b", "c", "d") - checkAnswer(data.select(hex('a)), Seq(Row("1C"))) - checkAnswer(data.select(hex('b)), Seq(Row("FFFFFFFFFFFFFFE4"))) - checkAnswer(data.select(hex('c)), Seq(Row("177828FED4"))) - checkAnswer(data.select(hex('d)), Seq(Row("68656C6C6F"))) + checkAnswer(data.select(hex(Symbol("a"))), Seq(Row("1C"))) + checkAnswer(data.select(hex(Symbol("b"))), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.select(hex(Symbol("c"))), Seq(Row("177828FED4"))) + checkAnswer(data.select(hex(Symbol("d"))), Seq(Row("68656C6C6F"))) checkAnswer(data.selectExpr("hex(a)"), Seq(Row("1C"))) checkAnswer(data.selectExpr("hex(b)"), Seq(Row("FFFFFFFFFFFFFFE4"))) checkAnswer(data.selectExpr("hex(c)"), Seq(Row("177828FED4"))) @@ -328,8 +328,8 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { test("unhex") { val data = Seq(("1C", "737472696E67")).toDF("a", "b") - checkAnswer(data.select(unhex('a)), Row(Array[Byte](28.toByte))) - checkAnswer(data.select(unhex('b)), Row("string".getBytes(StandardCharsets.UTF_8))) + checkAnswer(data.select(unhex(Symbol("a"))), Row(Array[Byte](28.toByte))) + checkAnswer(data.select(unhex(Symbol("b"))), Row("string".getBytes(StandardCharsets.UTF_8))) checkAnswer(data.selectExpr("unhex(a)"), Row(Array[Byte](28.toByte))) checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes(StandardCharsets.UTF_8))) checkAnswer(data.selectExpr("""unhex("##")"""), Row(null)) @@ -366,8 +366,8 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer( df.select( - shiftleft('a, 1), shiftleft('b, 1), shiftleft('c, 1), shiftleft('d, 1), - shiftLeft('f, 1)), // test deprecated one. + shiftleft(Symbol("a"), 1), shiftleft(Symbol("b"), 1), shiftleft(Symbol("c"), 1), + shiftleft(Symbol("d"), 1), shiftLeft(Symbol("f"), 1)), // test deprecated one. 
Row(42.toLong, 42, 42.toShort, 42.toByte, null)) checkAnswer( @@ -383,8 +383,8 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer( df.select( - shiftright('a, 1), shiftright('b, 1), shiftright('c, 1), shiftright('d, 1), - shiftRight('f, 1)), // test deprecated one. + shiftright(Symbol("a"), 1), shiftright(Symbol("b"), 1), shiftright(Symbol("c"), 1), + shiftright(Symbol("d"), 1), shiftRight(Symbol("f"), 1)), // test deprecated one. Row(21.toLong, 21, 21.toShort, 21.toByte, null)) checkAnswer( @@ -400,8 +400,9 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer( df.select( - shiftrightunsigned('a, 1), shiftrightunsigned('b, 1), shiftrightunsigned('c, 1), - shiftrightunsigned('d, 1), shiftRightUnsigned('f, 1)), // test deprecated one. + shiftrightunsigned(Symbol("a"), 1), shiftrightunsigned(Symbol("b"), 1), + shiftrightunsigned(Symbol("c"), 1), shiftrightunsigned(Symbol("d"), 1), + shiftRightUnsigned(Symbol("f"), 1)), // test deprecated one. Row(9223372036854775787L, 21, 21.toShort, 21.toByte, null)) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 31b9427d1bcf..41b75ba9e594 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3110,15 +3110,16 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark val df = spark.read.format(format).load(dir.getCanonicalPath) checkPushedFilters( format, - df.where(('id < 2 and 's.contains("foo")) or ('id > 10 and 's.contains("bar"))), + df.where((Symbol("id") < 2 and Symbol("s").contains("foo")) or (Symbol("id") > 10 and + Symbol("s").contains("bar"))), Array(sources.Or(sources.LessThan("id", 2), sources.GreaterThan("id", 10)))) checkPushedFilters( format, - df.where('s.contains("foo") or ('id > 10 and 's.contains("bar"))), - Array.empty) + df.where(Symbol("s").contains("foo") or (Symbol("id") > 10 and + Symbol("s").contains("bar"))), Array.empty) checkPushedFilters( format, - df.where('id < 2 and not('id > 10 and 's.contains("bar"))), + df.where(Symbol("id") < 2 and not(Symbol("id") > 10 and Symbol("s").contains("bar"))), Array(sources.IsNotNull("id"), sources.LessThan("id", 2))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala index 5ce5d36c5e8f..1f4172fe358c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ShowCreateTableSuite.scala @@ -137,7 +137,8 @@ abstract class ShowCreateTableSuite extends QueryTest with SQLTestUtils { withTable("ddl_test") { spark .range(3) - .select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd, 'id as 'e) + .select(Symbol("id") as Symbol("a"), Symbol("id") as Symbol("b"), + Symbol("id") as Symbol("c"), Symbol("id") as Symbol("d"), Symbol("id") as Symbol("e")) .write .mode("overwrite") .partitionBy("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 87e7641c87f6..1a3369021dee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -405,10 +405,10 @@ class StatisticsCollectionSuite 
extends StatisticsCollectionTestBase with Shared withSQLConf(SQLConf.CBO_ENABLED.key -> "true") { withTable("TBL1", "TBL") { import org.apache.spark.sql.functions._ - val df = spark.range(1000L).select('id, - 'id * 2 as "FLD1", - 'id * 12 as "FLD2", - lit("aaa") + 'id as "fld3") + val df = spark.range(1000L).select(Symbol("id"), + Symbol("id") * 2 as "FLD1", + Symbol("id") * 12 as "FLD2", + lit("aaa") + Symbol("id") as "fld3") df.write .mode(SaveMode.Overwrite) .bucketBy(10, "id", "FLD1", "FLD2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 7d3faaef2cd4..39b31baf733c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -422,7 +422,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { ("N", Integer.valueOf(3), null)).toDF("a", "b", "c") val udf1 = udf((a: String, b: Int, c: Any) => a + b + c) - val df = input.select(udf1('a, 'b, 'c)) + val df = input.select(udf1(Symbol("a"), Symbol("b"), Symbol("c"))) checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null"))) // test Java UDF. Java UDF can't have primitive inputs, as it's generic typed. @@ -431,7 +431,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { t1 + t2 + t3 } }, StringType) - val df2 = input.select(udf2('a, 'b, 'c)) + val df2 = input.select(udf2(Symbol("a"), Symbol("b"), Symbol("c"))) checkAnswer(df2, Seq(Row("null1x"), Row("Mnully"), Row("N3null"))) } @@ -525,7 +525,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { .format(dtf) val plusSec = udf((i: java.time.Instant) => i.plusSeconds(1)) val df = spark.sql("SELECT TIMESTAMP '2019-02-26 23:59:59Z' as t") - .select(plusSec('t).cast(StringType)) + .select(plusSec(Symbol("t")).cast(StringType)) checkAnswer(df, Row(expected) :: Nil) } @@ -533,7 +533,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { val expected = java.time.LocalDate.parse("2019-02-27").toString val plusDay = udf((i: java.time.LocalDate) => i.plusDays(1)) val df = spark.sql("SELECT DATE '2019-02-26' as d") - .select(plusDay('d).cast(StringType)) + .select(plusDay(Symbol("d")).cast(StringType)) checkAnswer(df, Row(expected) :: Nil) } @@ -552,7 +552,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { spark.udf.register("buildLocalDateInstantType", udf((d: LocalDate, i: Instant) => LocalDateInstantType(d, i))) checkAnswer(df.selectExpr(s"buildLocalDateInstantType(d, i) as di") - .select('di.cast(StringType)), + .select(Symbol("di").cast(StringType)), Row(s"{$expectedDate, $expectedInstant}") :: Nil) // test null cases @@ -582,7 +582,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { spark.udf.register("buildTimestampInstantType", udf((t: Timestamp, i: Instant) => TimestampInstantType(t, i))) checkAnswer(df.selectExpr("buildTimestampInstantType(t, i) as ti") - .select('ti.cast(StringType)), + .select(Symbol("ti").cast(StringType)), Row(s"{$expectedTimestamp, $expectedInstant}")) // test null cases @@ -601,11 +601,11 @@ class UDFSuite extends QueryTest with SharedSparkSession { // without explicit type val udf1 = udf((i: String) => null) assert(udf1.asInstanceOf[SparkUserDefinedFunction] .dataType === NullType) - checkAnswer(Seq("1").toDF("a").select(udf1('a)), Row(null) :: Nil) + checkAnswer(Seq("1").toDF("a").select(udf1(Symbol("a"))), Row(null) :: Nil) // with explicit type val udf2 = udf((i: String) => null.asInstanceOf[String]) 
assert(udf2.asInstanceOf[SparkUserDefinedFunction].dataType === StringType) - checkAnswer(Seq("1").toDF("a").select(udf1('a)), Row(null) :: Nil) + checkAnswer(Seq("1").toDF("a").select(udf1(Symbol("a"))), Row(null) :: Nil) } test("SPARK-28321 0-args Java UDF should not be called only once") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index b6ab60a91955..68e15d3ab312 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -82,14 +82,14 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque } test("register user type: MyDenseVector for MyLabeledPoint") { - val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } + val labels: RDD[Double] = pointsRDD.select(Symbol("label")).rdd.map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() assert(labelsArrays.size === 2) assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) val features: RDD[TestUDT.MyDenseVector] = - pointsRDD.select('features).rdd.map { case Row(v: TestUDT.MyDenseVector) => v } + pointsRDD.select(Symbol("features")).rdd.map { case Row(v: TestUDT.MyDenseVector) => v } val featuresArrays: Array[TestUDT.MyDenseVector] = features.collect() assert(featuresArrays.size === 2) assert(featuresArrays.contains(new TestUDT.MyDenseVector(Array(0.1, 1.0)))) @@ -137,8 +137,9 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque val df = Seq((1, vec)).toDF("int", "vec") assert(vec === df.collect()(0).getAs[TestUDT.MyDenseVector](1)) assert(vec === df.take(1)(0).getAs[TestUDT.MyDenseVector](1)) - checkAnswer(df.limit(1).groupBy('int).agg(first('vec)), Row(1, vec)) - checkAnswer(df.orderBy('int).limit(1).groupBy('int).agg(first('vec)), Row(1, vec)) + checkAnswer(df.limit(1).groupBy(Symbol("int")).agg(first(Symbol("vec"))), Row(1, vec)) + checkAnswer(df.orderBy(Symbol("int")).limit(1) + .groupBy(Symbol("int")).agg(first(Symbol("vec"))), Row(1, vec)) } test("UDTs with JSON") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index 46112d40f08b..aa64ee26f475 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -209,7 +209,7 @@ private [connector] trait SessionCatalogTest[T <: Table, Catalog <: TestV2Sessio verifyTable(t1, df) // Check that appends are by name - df.select('data, 'id).write.format(v2Format).mode("append").saveAsTable(t1) + df.select(Symbol("data"), Symbol("id")).write.format(v2Format).mode("append").saveAsTable(t1) verifyTable(t1, df.union(df)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala index efb87dafe0ff..d0948d950230 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala @@ -92,7 +92,7 @@ class DataSourceV2DataFrameSuite assert(spark.table(t1).count() === 0) 
// appends are by name not by position - df.select('data, 'id).write.mode("append").saveAsTable(t1) + df.select(Symbol("data"), Symbol("id")).write.mode("append").saveAsTable(t1) checkAnswer(spark.table(t1), df) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index 28cb448c400c..79d6eef80965 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -66,8 +66,8 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 10).map(i => Row(i, -i))) - checkAnswer(df.select('j), (0 until 10).map(i => Row(-i))) - checkAnswer(df.filter('i > 5), (6 until 10).map(i => Row(i, -i))) + checkAnswer(df.select(Symbol("j")), (0 until 10).map(i => Row(-i))) + checkAnswer(df.filter(Symbol("i") > 5), (6 until 10).map(i => Row(i, -i))) } } } @@ -78,7 +78,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 10).map(i => Row(i, -i))) - val q1 = df.select('j) + val q1 = df.select(Symbol("j")) checkAnswer(q1, (0 until 10).map(i => Row(-i))) if (cls == classOf[AdvancedDataSourceV2]) { val batch = getBatch(q1) @@ -90,7 +90,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(batch.requiredSchema.fieldNames === Seq("j")) } - val q2 = df.filter('i > 3) + val q2 = df.filter(Symbol("i") > 3) checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) if (cls == classOf[AdvancedDataSourceV2]) { val batch = getBatch(q2) @@ -102,7 +102,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(batch.requiredSchema.fieldNames === Seq("i", "j")) } - val q3 = df.select('i).filter('i > 6) + val q3 = df.select(Symbol("i")).filter(Symbol("i") > 6) checkAnswer(q3, (7 until 10).map(i => Row(i))) if (cls == classOf[AdvancedDataSourceV2]) { val batch = getBatch(q3) @@ -114,16 +114,16 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(batch.requiredSchema.fieldNames === Seq("i")) } - val q4 = df.select('j).filter('j < -10) + val q4 = df.select(Symbol("j")).filter(Symbol("j") < -10) checkAnswer(q4, Nil) if (cls == classOf[AdvancedDataSourceV2]) { val batch = getBatch(q4) - // 'j < 10 is not supported by the testing data source. + // Symbol("j") < 10 is not supported by the testing data source. assert(batch.filters.isEmpty) assert(batch.requiredSchema.fieldNames === Seq("j")) } else { val batch = getJavaBatch(q4) - // 'j < 10 is not supported by the testing data source. + // Symbol("j") < 10 is not supported by the testing data source. 
assert(batch.filters.isEmpty) assert(batch.requiredSchema.fieldNames === Seq("j")) } @@ -136,8 +136,8 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 90).map(i => Row(i, -i))) - checkAnswer(df.select('j), (0 until 90).map(i => Row(-i))) - checkAnswer(df.filter('i > 50), (51 until 90).map(i => Row(i, -i))) + checkAnswer(df.select(Symbol("j")), (0 until 90).map(i => Row(-i))) + checkAnswer(df.filter(Symbol("i") > 50), (51 until 90).map(i => Row(i, -i))) } } } @@ -161,12 +161,12 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS "supports external metadata") { withTempDir { dir => val cls = classOf[SupportsExternalMetadataWritableDataSource].getName - spark.range(10).select('id as 'i, -'id as 'j).write.format(cls) - .option("path", dir.getCanonicalPath).mode("append").save() + spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls).option("path", dir.getCanonicalPath).mode("append").save() val schema = new StructType().add("i", "long").add("j", "long") checkAnswer( spark.read.format(cls).option("path", dir.getCanonicalPath).schema(schema).load(), - spark.range(10).select('id, -'id)) + spark.range(10).select(Symbol("id"), -Symbol("id"))) } } @@ -177,25 +177,25 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS val df = spark.read.format(cls.getName).load() checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) - val groupByColA = df.groupBy('i).agg(sum('j)) + val groupByColA = df.groupBy(Symbol("i")).agg(sum(Symbol("j"))) checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) assert(collectFirst(groupByColA.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColAB = df.groupBy('i, 'j).agg(count("*")) + val groupByColAB = df.groupBy(Symbol("i"), Symbol("j")).agg(count("*")) checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) assert(collectFirst(groupByColAB.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColB = df.groupBy('j).agg(sum('i)) + val groupByColB = df.groupBy(Symbol("j")).agg(sum(Symbol("i"))) checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) assert(collectFirst(groupByColB.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e }.isDefined) - val groupByAPlusB = df.groupBy('i + 'j).agg(count("*")) + val groupByAPlusB = df.groupBy(Symbol("i") + Symbol("j")).agg(count("*")) checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) assert(collectFirst(groupByAPlusB.queryExecution.executedPlan) { case e: ShuffleExchangeExec => e @@ -234,38 +234,38 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS val path = file.getCanonicalPath assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) - spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) - .option("path", path).mode("append").save() + spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls.getName).option("path", path).mode("append").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), - spark.range(10).select('id, -'id)) + spark.range(10).select(Symbol("id"), -Symbol("id"))) // default save mode is ErrorIfExists intercept[AnalysisException] { - 
spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) - .option("path", path).save() + spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls.getName).option("path", path).save() } - spark.range(10).select('id as 'i, -'id as 'j).write.mode("append").format(cls.getName) - .option("path", path).save() + spark.range(10).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.mode("append").format(cls.getName).option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), - spark.range(10).union(spark.range(10)).select('id, -'id)) + spark.range(10).union(spark.range(10)).select(Symbol("id"), -Symbol("id"))) - spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) - .option("path", path).mode("overwrite").save() + spark.range(5).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls.getName).option("path", path).mode("overwrite").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), - spark.range(5).select('id, -'id)) + spark.range(5).select(Symbol("id"), -Symbol("id"))) val e = intercept[AnalysisException] { - spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) - .option("path", path).mode("ignore").save() + spark.range(5).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls.getName).option("path", path).mode("ignore").save() } assert(e.message.contains("please use Append or Overwrite modes instead")) val e2 = intercept[AnalysisException] { - spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) - .option("path", path).mode("error").save() + spark.range(5).select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls.getName).option("path", path).mode("error").save() } assert(e2.getMessage.contains("please use Append or Overwrite modes instead")) @@ -281,7 +281,8 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS } } // this input data will fail to read middle way. 
- val input = spark.range(15).select(failingUdf('id).as('i)).select('i, -'i as 'j) + val input = spark.range(15).select(failingUdf(Symbol("id")).as(Symbol("i"))) + .select(Symbol("i"), -Symbol("i") as Symbol("j")) val e3 = intercept[SparkException] { input.write.format(cls.getName).option("path", path).mode("overwrite").save() } @@ -299,11 +300,12 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) val numPartition = 6 - spark.range(0, 10, 1, numPartition).select('id as 'i, -'id as 'j).write.format(cls.getName) - .mode("append").option("path", path).save() + spark.range(0, 10, 1, numPartition) + .select(Symbol("id") as Symbol("i"), -Symbol("id") as Symbol("j")) + .write.format(cls.getName).mode("append").option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), - spark.range(10).select('id, -'id)) + spark.range(10).select(Symbol("id"), -Symbol("id"))) assert(SimpleCounter.getCounter == numPartition, "method onDataWriterCommit should be called as many as the number of partitions") @@ -320,7 +322,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS test("SPARK-23301: column pruning with arbitrary expressions") { val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() - val q1 = df.select('i + 1) + val q1 = df.select(Symbol("i") + 1) checkAnswer(q1, (1 until 11).map(i => Row(i))) val batch1 = getBatch(q1) assert(batch1.requiredSchema.fieldNames === Seq("i")) @@ -330,15 +332,15 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS val batch2 = getBatch(q2) assert(batch2.requiredSchema.isEmpty) - // 'j === 1 can't be pushed down, but we should still be able do column pruning - val q3 = df.filter('j === -1).select('j * 2) + // Symbol("j") === 1 can't be pushed down, but we should still be able do column pruning + val q3 = df.filter(Symbol("j") === -1).select(Symbol("j") * 2) checkAnswer(q3, Row(-2)) val batch3 = getBatch(q3) assert(batch3.filters.isEmpty) assert(batch3.requiredSchema.fieldNames === Seq("j")) // column pruning should work with other operators. 
- val q4 = df.sort('i).limit(1).select('i + 1) + val q4 = df.sort(Symbol("i")).limit(1).select(Symbol("i") + 1) checkAnswer(q4, Row(1)) val batch4 = getBatch(q4) assert(batch4.requiredSchema.fieldNames === Seq("i")) @@ -360,7 +362,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() checkCanonicalizedOutput(df, 2, 2) - checkCanonicalizedOutput(df.select('i), 2, 1) + checkCanonicalizedOutput(df.select(Symbol("i")), 2, 1) } test("SPARK-25425: extra options should override sessions options during reading") { @@ -399,7 +401,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS withTempView("t1") { val t2 = spark.read.format(classOf[SimpleDataSourceV2].getName).load() Seq(2, 3).toDF("a").createTempView("t1") - val df = t2.where("i < (select max(a) from t1)").select('i) + val df = t2.where("i < (select max(a) from t1)").select(Symbol("i")) val subqueries = stripAQEPlan(df.queryExecution.executedPlan).collect { case p => p.subqueries }.flatten @@ -418,8 +420,8 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() - val q1 = df.select('i).filter('i > 6) - val q2 = df.select('i).filter('i > 5) + val q1 = df.select(Symbol("i")).filter(Symbol("i") > 6) + val q2 = df.select(Symbol("i")).filter(Symbol("i") > 5) val scan1 = getScanExec(q1) val scan2 = getScanExec(q2) assert(!scan1.equals(scan2)) @@ -432,7 +434,7 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS withClue(cls.getName) { val df = spark.read.format(cls.getName).load() // before SPARK-33267 below query just threw NPE - df.select('i).where("i in (1, null)").collect() + df.select(Symbol("i")).where("i in (1, null)").collect() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index 3aad644655aa..78ba85ad2d67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -71,7 +71,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with saveMode: SaveMode, withCatalogOption: Option[String], partitionBy: Seq[String]): Unit = { - val df = spark.range(10).withColumn("part", 'id % 5) + val df = spark.range(10).withColumn("part", Symbol("id") % 5) val dfw = df.write.format(format).mode(saveMode).option("name", "t1") withCatalogOption.foreach(cName => dfw.option("catalog", cName)) dfw.partitionBy(partitionBy: _*).save() @@ -136,7 +136,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with test("Ignore mode if table exists - session catalog") { sql(s"create table t1 (id bigint) using $format") - val df = spark.range(10).withColumn("part", 'id % 5) + val df = spark.range(10).withColumn("part", Symbol("id") % 5) val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") dfw.save() @@ -148,7 +148,7 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with test("Ignore mode if table exists - testcat catalog") { sql(s"create table $catalogName.t1 (id bigint) using $format") - val df = spark.range(10).withColumn("part", 
'id % 5) + val df = spark.range(10).withColumn("part", Symbol("id") % 5) val dfw = df.write.format(format).mode(SaveMode.Ignore).option("name", "t1") dfw.option("catalog", catalogName).save() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala index a33b9fad7ff4..06fc2022c01a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/AggregatingAccumulatorSuite.scala @@ -35,9 +35,9 @@ class AggregatingAccumulatorSuite extends SparkFunSuite with SharedSparkSession with ExpressionEvalHelper { - private val a = 'a.long - private val b = 'b.string - private val c = 'c.double + private val a = Symbol("a").long + private val b = Symbol("b").string + private val c = Symbol("c").double private val inputAttributes = Seq(a, b, c) private def str(s: String): UTF8String = UTF8String.fromString(s) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala index cef870b24998..2dcc3881a7bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BaseScriptTransformationSuite.scala @@ -133,11 +133,11 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU """.stripMargin) checkAnswer(query, identity, df.select( - 'a.cast("string"), - 'b.cast("string"), - 'c.cast("string"), - 'd.cast("string"), - 'e.cast("string")).collect()) + Symbol("a").cast("string"), + Symbol("b").cast("string"), + Symbol("c").cast("string"), + Symbol("d").cast("string"), + Symbol("e").cast("string")).collect()) } } @@ -166,8 +166,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU ioschema = defaultIOSchema.copy(schemaLess = true) ), df.select( - 'a.cast("string").as("key"), - 'b.cast("string").as("value")).collect()) + Symbol("a").cast("string").as("key"), + Symbol("b").cast("string").as("value")).collect()) checkAnswer( df, @@ -183,8 +183,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU ioschema = defaultIOSchema.copy(schemaLess = true) ), df.select( - 'a.cast("string").as("key"), - 'b.cast("string").as("value")).collect()) + Symbol("a").cast("string").as("key"), + Symbol("b").cast("string").as("value")).collect()) checkAnswer( df, @@ -199,7 +199,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU ioschema = defaultIOSchema.copy(schemaLess = true) ), df.select( - 'a.cast("string").as("key"), + Symbol("a").cast("string").as("key"), lit(null)).collect()) } } @@ -265,7 +265,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU child = child, ioschema = serde ), - df.select('a, 'b, 'c, 'd, 'e, 'f, 'g, 'h, 'i, 'j).collect()) + df.select(Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), Symbol("e"), Symbol("f"), + Symbol("g"), Symbol("h"), Symbol("i"), Symbol("j")).collect()) } } } @@ -309,7 +310,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU child = child, ioschema = defaultIOSchema ), - df.select('a, 'b.cast("string"), 'c.cast("string"), 'd.cast("string"), 'e).collect()) + df.select(Symbol("a"), Symbol("b").cast("string"), Symbol("c").cast("string"), + 
Symbol("d").cast("string"), Symbol("e")).collect()) } } @@ -331,7 +333,7 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU |USING 'cat' AS (a timestamp, b date) |FROM v """.stripMargin) - checkAnswer(query, identity, df.select('a, 'b).collect()) + checkAnswer(query, identity, df.select(Symbol("a"), Symbol("b")).collect()) } } } @@ -366,11 +368,11 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | NULL DEFINED AS 'NULL' |FROM v """.stripMargin), identity, df.select( - 'a.cast("string"), - 'b.cast("string"), - 'c.cast("string"), - 'd.cast("string"), - 'e.cast("string")).collect()) + Symbol("a").cast("string"), + Symbol("b").cast("string"), + Symbol("c").cast("string"), + Symbol("d").cast("string"), + Symbol("e").cast("string")).collect()) // input/output with different delimit and show result checkAnswer( @@ -389,11 +391,11 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU |FROM v """.stripMargin), identity, df.select( concat_ws(",", - 'a.cast("string"), - 'b.cast("string"), - 'c.cast("string"), - 'd.cast("string"), - 'e.cast("string"))).collect()) + Symbol("a").cast("string"), + Symbol("b").cast("string"), + Symbol("c").cast("string"), + Symbol("d").cast("string"), + Symbol("e").cast("string"))).collect()) } } @@ -421,8 +423,8 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU ioschema = defaultIOSchema ), df.select( - 'a.cast("string").as("a"), - 'b.cast("string").as("b"), + Symbol("a").cast("string").as("a"), + Symbol("b").cast("string").as("b"), lit(null), lit(null)).collect()) } @@ -495,11 +497,11 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | FIELDS TERMINATED BY '\t' |FROM v """.stripMargin), identity, df.select( - 'a.cast("string"), - 'b.cast("string"), - 'c.cast("string"), - 'd.cast("string"), - 'e.cast("string")).collect()) + Symbol("a").cast("string"), + Symbol("b").cast("string"), + Symbol("c").cast("string"), + Symbol("d").cast("string"), + Symbol("e").cast("string")).collect()) // test '/path/to/script.py' with script not executable val e1 = intercept[TestFailedException] { @@ -515,11 +517,11 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | FIELDS TERMINATED BY '\t' |FROM v """.stripMargin), identity, df.select( - 'a.cast("string"), - 'b.cast("string"), - 'c.cast("string"), - 'd.cast("string"), - 'e.cast("string")).collect()) + Symbol("a").cast("string"), + Symbol("b").cast("string"), + Symbol("c").cast("string"), + Symbol("d").cast("string"), + Symbol("e").cast("string")).collect()) }.getMessage // Check with status exit code since in GA test, it may lose detail failed root cause. // Different root cause's exitcode is not same. 
@@ -540,11 +542,11 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | FIELDS TERMINATED BY '\t' |FROM v """.stripMargin), identity, df.select( - 'a.cast("string"), - 'b.cast("string"), - 'c.cast("string"), - 'd.cast("string"), - 'e.cast("string")).collect()) + Symbol("a").cast("string"), + Symbol("b").cast("string"), + Symbol("c").cast("string"), + Symbol("d").cast("string"), + Symbol("e").cast("string")).collect()) scriptFilePath.setExecutable(false) sql(s"ADD FILE ${scriptFilePath.getAbsolutePath}") @@ -561,11 +563,11 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | FIELDS TERMINATED BY '\t' |FROM v """.stripMargin), identity, df.select( - 'a.cast("string"), - 'b.cast("string"), - 'c.cast("string"), - 'd.cast("string"), - 'e.cast("string")).collect()) + Symbol("a").cast("string"), + Symbol("b").cast("string"), + Symbol("c").cast("string"), + Symbol("d").cast("string"), + Symbol("e").cast("string")).collect()) // test `python script.py` when file added checkAnswer( @@ -579,11 +581,11 @@ abstract class BaseScriptTransformationSuite extends SparkPlanTest with SQLTestU | FIELDS TERMINATED BY '\t' |FROM v """.stripMargin), identity, df.select( - 'a.cast("string"), - 'b.cast("string"), - 'c.cast("string"), - 'd.cast("string"), - 'e.cast("string")).collect()) + Symbol("a").cast("string"), + Symbol("b").cast("string"), + Symbol("c").cast("string"), + Symbol("d").cast("string"), + Symbol("e").cast("string")).collect()) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala index 4ff96e6574ca..3f0d6930dcd4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala @@ -26,9 +26,11 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper { test("basic") { val leftInput = Seq(create_row(1, "a"), create_row(1, "b"), create_row(2, "c")).iterator val rightInput = Seq(create_row(1, 2L), create_row(2, 3L), create_row(3, 4L)).iterator - val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string)) - val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long)) - val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int)) + val leftGrouped = GroupedIterator(leftInput, Seq(Symbol("i").int.at(0)), + Seq(Symbol("i").int, Symbol("s").string)) + val rightGrouped = GroupedIterator(rightInput, Seq(Symbol("i").int.at(0)), + Seq(Symbol("i").int, Symbol("l").long)) + val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq(Symbol("i").int)) val result = cogrouped.map { case (key, leftData, rightData) => @@ -52,9 +54,11 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") { val leftInput = Seq(create_row(2, "a")).iterator val rightInput = Seq(create_row(1, 2L)).iterator - val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string)) - val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long)) - val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int)) + val leftGrouped = GroupedIterator(leftInput, Seq(Symbol("i").int.at(0)), + Seq(Symbol("i").int, Symbol("s").string)) + val rightGrouped = 
GroupedIterator(rightInput, Seq(Symbol("i").int.at(0)), + Seq(Symbol("i").int, Symbol("l").long)) + val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq(Symbol("i").int)) val result = cogrouped.map { case (key, leftData, rightData) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala index 4b2a2b439c89..06c51cee0201 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala @@ -32,7 +32,7 @@ class GroupedIteratorSuite extends SparkFunSuite { val fromRow = encoder.createDeserializer() val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) val grouped = GroupedIterator(input.iterator.map(toRow), - Seq('i.int.at(0)), schema.toAttributes) + Seq(Symbol("i").int.at(0)), schema.toAttributes) val result = grouped.map { case (key, data) => @@ -59,7 +59,7 @@ class GroupedIteratorSuite extends SparkFunSuite { Row(3, 2L, "e")) val grouped = GroupedIterator(input.iterator.map(toRow), - Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes) + Seq(Symbol("i").int.at(0), Symbol("l").long.at(1)), schema.toAttributes) val result = grouped.map { case (key, data) => @@ -80,7 +80,7 @@ class GroupedIteratorSuite extends SparkFunSuite { val toRow = encoder.createSerializer() val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) val grouped = GroupedIterator(input.iterator.map(toRow), - Seq('i.int.at(0)), schema.toAttributes) + Seq(Symbol("i").int.at(0)), schema.toAttributes) assert(grouped.length == 2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index e851722fa4ea..43d09075b541 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -56,18 +56,19 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { } test("count is partially aggregated") { - val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed + val query = testData.groupBy(Symbol("value")).agg(count(Symbol("key"))).queryExecution.analyzed testPartialAggregationPlan(query) } test("count distinct is partially aggregated") { - val query = testData.groupBy('value).agg(count_distinct('key)).queryExecution.analyzed + val query = + testData.groupBy(Symbol("value")).agg(count_distinct(Symbol("key"))).queryExecution.analyzed testPartialAggregationPlan(query) } test("mixed aggregates are partially aggregated") { - val query = - testData.groupBy('value).agg(count('value), count_distinct('key)).queryExecution.analyzed + val query = testData.groupBy(Symbol("value")).agg(count(Symbol("value")), + count_distinct(Symbol("key"))).queryExecution.analyzed testPartialAggregationPlan(query) } @@ -190,45 +191,47 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper { } test("efficient terminal limit -> sort should use TakeOrderedAndProject") { - val query = testData.select('key, 'value).sort('key).limit(2) + val query = testData.select(Symbol("key"), Symbol("value")).sort(Symbol("key")).limit(2) val planned = query.queryExecution.executedPlan assert(planned.isInstanceOf[execution.TakeOrderedAndProjectExec]) - assert(planned.output === testData.select('key, 'value).logicalPlan.output) + assert(planned.output === 
testData.select(Symbol("key"), Symbol("value")).logicalPlan.output) } test("terminal limit -> project -> sort should use TakeOrderedAndProject") { - val query = testData.select('key, 'value).sort('key).select('value, 'key).limit(2) + val query = testData.select(Symbol("key"), + Symbol("value")).sort(Symbol("key")).select(Symbol("value"), Symbol("key")).limit(2) val planned = query.queryExecution.executedPlan assert(planned.isInstanceOf[execution.TakeOrderedAndProjectExec]) - assert(planned.output === testData.select('value, 'key).logicalPlan.output) + assert(planned.output === testData.select(Symbol("value"), Symbol("key")).logicalPlan.output) } test("terminal limits that are not handled by TakeOrderedAndProject should use CollectLimit") { - val query = testData.select('value).limit(2) + val query = testData.select(Symbol("value")).limit(2) val planned = query.queryExecution.sparkPlan assert(planned.isInstanceOf[CollectLimitExec]) - assert(planned.output === testData.select('value).logicalPlan.output) + assert(planned.output === testData.select(Symbol("value")).logicalPlan.output) } test("TakeOrderedAndProject can appear in the middle of plans") { - val query = testData.select('key, 'value).sort('key).limit(2).filter('key === 3) + val query = testData.select(Symbol("key"), + Symbol("value")).sort(Symbol("key")).limit(2).filter(Symbol("key") === 3) val planned = query.queryExecution.executedPlan assert(planned.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) } test("CollectLimit can appear in the middle of a plan when caching is used") { - val query = testData.select('key, 'value).limit(2).cache() + val query = testData.select(Symbol("key"), Symbol("value")).limit(2).cache() val planned = query.queryExecution.optimizedPlan.asInstanceOf[InMemoryRelation] assert(planned.cachedPlan.isInstanceOf[CollectLimitExec]) } test("TakeOrderedAndProjectExec appears only when number of limit is below the threshold.") { withSQLConf(SQLConf.TOP_K_SORT_FALLBACK_THRESHOLD.key -> "1000") { - val query0 = testData.select('value).orderBy('key).limit(100) + val query0 = testData.select(Symbol("value")).orderBy(Symbol("key")).limit(100) val planned0 = query0.queryExecution.executedPlan assert(planned0.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isDefined) - val query1 = testData.select('value).orderBy('key).limit(2000) + val query1 = testData.select(Symbol("value")).orderBy(Symbol("key")).limit(2000) val planned1 = query1.queryExecution.executedPlan assert(planned1.find(_.isInstanceOf[TakeOrderedAndProjectExec]).isEmpty) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala index 751078d08fda..5830f8b179e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantSortsSuite.scala @@ -51,7 +51,7 @@ abstract class RemoveRedundantSortsSuiteBase test("remove redundant sorts with limit") { withTempView("t") { - spark.range(100).select('id as "key").createOrReplaceTempView("t") + spark.range(100).select(Symbol("id") as "key").createOrReplaceTempView("t") val query = """ |SELECT key FROM @@ -64,8 +64,8 @@ abstract class RemoveRedundantSortsSuiteBase test("remove redundant sorts with broadcast hash join") { withTempView("t1", "t2") { - spark.range(1000).select('id as "key").createOrReplaceTempView("t1") - spark.range(1000).select('id as 
"key").createOrReplaceTempView("t2") + spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t1") + spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t2") val queryTemplate = """ |SELECT /*+ BROADCAST(%s) */ t1.key FROM @@ -100,8 +100,8 @@ abstract class RemoveRedundantSortsSuiteBase test("remove redundant sorts with sort merge join") { withTempView("t1", "t2") { - spark.range(1000).select('id as "key").createOrReplaceTempView("t1") - spark.range(1000).select('id as "key").createOrReplaceTempView("t2") + spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t1") + spark.range(1000).select(Symbol("id") as "key").createOrReplaceTempView("t2") val query = """ |SELECT /*+ MERGE(t1) */ t1.key FROM | (SELECT key FROM t1 WHERE key > 10 ORDER BY key DESC LIMIT 10) t1 @@ -123,15 +123,15 @@ abstract class RemoveRedundantSortsSuiteBase test("cached sorted data doesn't need to be re-sorted") { withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "true") { - val df = spark.range(1000).select('id as "key").sort('key.desc).cache() - val resorted = df.sort('key.desc) - val sortedAsc = df.sort('key.asc) + val df = spark.range(1000).select(Symbol("id") as "key").sort(Symbol("key").desc).cache() + val resorted = df.sort(Symbol("key").desc) + val sortedAsc = df.sort(Symbol("key").asc) checkNumSorts(df, 0) checkNumSorts(resorted, 0) checkNumSorts(sortedAsc, 1) val result = resorted.collect() withSQLConf(SQLConf.REMOVE_REDUNDANT_SORTS_ENABLED.key -> "false") { - val resorted = df.sort('key.desc) + val resorted = df.sort(Symbol("key").desc) checkNumSorts(resorted, 1) checkAnswer(resorted, result) } @@ -140,7 +140,7 @@ abstract class RemoveRedundantSortsSuiteBase test("SPARK-33472: shuffled join with different left and right side partition numbers") { withTempView("t1", "t2") { - spark.range(0, 100, 1, 2).select('id as "key").createOrReplaceTempView("t1") + spark.range(0, 100, 1, 2).select(Symbol("id") as "key").createOrReplaceTempView("t1") (0 to 100).toDF("key").createOrReplaceTempView("t2") val queryTemplate = """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index d672a75a21a8..c214c4a73d9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -600,7 +600,8 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { spark.range(10).write.saveAsTable("add_col") withView("v") { sql("CREATE VIEW v AS SELECT * FROM add_col") - spark.range(10).select('id, 'id as 'a).write.mode("overwrite").saveAsTable("add_col") + spark.range(10).select(Symbol("id"), Symbol("id") as Symbol("a")) + .write.mode("overwrite").saveAsTable("add_col") checkAnswer(sql("SELECT * FROM v"), spark.range(10).toDF()) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index 6a4f3f62641f..a340bbcafcdf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -43,13 +43,15 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { checkAnswer( input.toDF("a", "b", "c"), - (child: SparkPlan) => SortExec('a.asc :: 'b.asc :: Nil, global = true, child = child), + (child: SparkPlan) => + SortExec(Symbol("a").asc :: Symbol("b").asc 
:: Nil, global = true, child = child), input.sortBy(t => (t._1, t._2)).map(Row.fromTuple), sortAnswers = false) checkAnswer( input.toDF("a", "b", "c"), - (child: SparkPlan) => SortExec('b.asc :: 'a.asc :: Nil, global = true, child = child), + (child: SparkPlan) => + SortExec(Symbol("b").asc :: Symbol("a").asc :: Nil, global = true, child = child), input.sortBy(t => (t._2, t._1)).map(Row.fromTuple), sortAnswers = false) } @@ -58,9 +60,9 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF().selectExpr("NULL as a"), (child: SparkPlan) => - GlobalLimitExec(10, SortExec('a.asc :: Nil, global = true, child = child)), + GlobalLimitExec(10, SortExec(Symbol("a").asc :: Nil, global = true, child = child)), (child: SparkPlan) => - GlobalLimitExec(10, ReferenceSort('a.asc :: Nil, global = true, child)), + GlobalLimitExec(10, ReferenceSort(Symbol("a").asc :: Nil, global = true, child)), sortAnswers = false ) } @@ -69,15 +71,15 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), (child: SparkPlan) => - GlobalLimitExec(10, SortExec('a.asc :: Nil, global = true, child = child)), + GlobalLimitExec(10, SortExec(Symbol("a").asc :: Nil, global = true, child = child)), (child: SparkPlan) => - GlobalLimitExec(10, ReferenceSort('a.asc :: Nil, global = true, child)), + GlobalLimitExec(10, ReferenceSort(Symbol("a").asc :: Nil, global = true, child)), sortAnswers = false ) } test("sorting does not crash for large inputs") { - val sortOrder = 'a.asc :: Nil + val sortOrder = Symbol("a").asc :: Nil val stringLength = 1024 * 1024 * 2 checkThatPlansAgree( Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1), @@ -91,8 +93,8 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => SortExec('a.asc :: Nil, global = true, child = child), - (child: SparkPlan) => ReferenceSort('a.asc :: Nil, global = true, child), + (child: SparkPlan) => SortExec(Symbol("a").asc :: Nil, global = true, child = child), + (child: SparkPlan) => ReferenceSort(Symbol("a").asc :: Nil, global = true, child), sortAnswers = false) } } @@ -105,7 +107,9 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { ) checkAnswer( input.toDF("a", "b", "c"), - (child: SparkPlan) => SortExec(Stream('a.asc, 'b.asc, 'c.asc), global = true, child = child), + (child: SparkPlan) => + SortExec( + Stream(Symbol("a").asc, Symbol("b").asc, Symbol("c").asc), global = true, child = child), input.sortBy(t => (t._1, t._2, t._3)).map(Row.fromTuple), sortAnswers = false) } @@ -114,8 +118,8 @@ class SortSuite extends SparkPlanTest with SharedSparkSession { for ( dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType); nullable <- Seq(true, false); - sortOrder <- - Seq('a.asc :: Nil, 'a.asc_nullsLast :: Nil, 'a.desc :: Nil, 'a.desc_nullsFirst :: Nil); + sortOrder <- Seq(Symbol("a").asc :: Nil, Symbol("a").asc_nullsLast :: Nil, + Symbol("a").desc :: Nil, Symbol("a").desc_nullsFirst :: Nil); randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index f1788e9c31af..8ac1e25231ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -281,7 +281,7 @@ class SparkSqlParserSuite extends AnalysisTest { |FROM testData """.stripMargin, ScriptTransformation( - Seq('a, 'b, 'c), + Seq(Symbol("a"), Symbol("b"), Symbol("c")), "cat", Seq(AttributeReference("a", StringType)(), AttributeReference("b", StringType)(), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala index 0ed0126add7a..31a949178241 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SubExprEliminationBenchmark.scala @@ -49,7 +49,7 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark { val schema = writeWideRow(path.getAbsolutePath, rowsNum, numCols) val cols = (0 until numCols).map { idx => - from_json('value, schema).getField(s"col$idx") + from_json(Symbol("value"), schema).getField(s"col$idx") } Seq( @@ -88,7 +88,7 @@ object SubExprEliminationBenchmark extends SqlBasedBenchmark { val schema = writeWideRow(path.getAbsolutePath, rowsNum, numCols) val predicate = (0 until numCols).map { idx => - (from_json('value, schema).getField(s"col$idx") >= Literal(100000)).expr + (from_json(Symbol("value"), schema).getField(s"col$idx") >= Literal(100000)).expr }.asInstanceOf[Seq[Expression]].reduce(Or) Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala index 376d330ebeb7..07264422769b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala @@ -52,7 +52,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSparkSession { private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan) val limit = 250 - val sortOrder = 'a.desc :: 'b.desc :: Nil + val sortOrder = Symbol("a").desc :: Symbol("b").desc :: Nil test("TakeOrderedAndProject.doExecute without project") { withClue(s"seed = $seed") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index c073df2017a5..54acb8758e2c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -234,7 +234,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession import testImplicits._ withTempPath { dir => val path = dir.getCanonicalPath - val df = spark.range(10).select(Seq.tabulate(201) {i => ('id + i).as(s"c$i")} : _*) + val df = spark.range(10).select(Seq.tabulate(201) {i => (Symbol("id") + i).as(s"c$i")} : _*) df.write.mode(SaveMode.Overwrite).parquet(path) withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "202", @@ -251,7 +251,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession test("Control splitting consume function by operators with config") { import 
testImplicits._ - val df = spark.range(10).select(Seq.tabulate(2) {i => ('id + i).as(s"c$i")} : _*) + val df = spark.range(10).select(Seq.tabulate(2) {i => (Symbol("id") + i).as(s"c$i")} : _*) Seq(true, false).foreach { config => withSQLConf(SQLConf.WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR.key -> s"$config") { @@ -314,9 +314,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_USE_ID_IN_CLASS_NAME.key -> "true") { // the same query run twice should produce identical code, which would imply a hit in // the generated code cache. - val ds1 = spark.range(3).select('id + 2) + val ds1 = spark.range(3).select(Symbol("id") + 2) val code1 = genCode(ds1) - val ds2 = spark.range(3).select('id + 2) + val ds2 = spark.range(3).select(Symbol("id") + 2) val code2 = genCode(ds2) // same query shape as above, deliberately assert(code1 == code2, "Should produce same code") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 122bc2d1e59a..ccc777a9649c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -228,11 +228,12 @@ class AdaptiveQueryExecSuite withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true") { - val df1 = spark.range(10).withColumn("a", 'id) - val df2 = spark.range(10).withColumn("b", 'id) + val df1 = spark.range(10).withColumn("a", Symbol("id")) + val df2 = spark.range(10).withColumn("b", Symbol("id")) withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer") - .groupBy('a).count() + val testDf = + df1.where(Symbol("a") > 10).join(df2.where(Symbol("b") > 10), Seq("id"), "left_outer") + .groupBy(Symbol("a")).count() checkAnswer(testDf, Seq()) val plan = testDf.queryExecution.executedPlan assert(find(plan)(_.isInstanceOf[SortMergeJoinExec]).isDefined) @@ -244,8 +245,9 @@ class AdaptiveQueryExecSuite } withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { - val testDf = df1.where('a > 10).join(df2.where('b > 10), Seq("id"), "left_outer") - .groupBy('a).count() + val testDf = + df1.where(Symbol("a") > 10).join(df2.where(Symbol("b") > 10), Seq("id"), "left_outer") + .groupBy(Symbol("a")).count() checkAnswer(testDf, Seq()) val plan = testDf.queryExecution.executedPlan assert(find(plan)(_.isInstanceOf[BroadcastHashJoinExec]).isDefined) @@ -693,17 +695,17 @@ class AdaptiveQueryExecSuite spark .range(0, 1000, 1, 10) .select( - when('id < 250, 249) - .when('id >= 750, 1000) - .otherwise('id).as("key1"), - 'id as "value1") + when(Symbol("id") < 250, 249) + .when(Symbol("id") >= 750, 1000) + .otherwise(Symbol("id")).as("key1"), + Symbol("id") as "value1") .createOrReplaceTempView("skewData1") spark .range(0, 1000, 1, 10) .select( - when('id < 250, 249) - .otherwise('id).as("key2"), - 'id as "value2") + when(Symbol("id") < 250, 249) + .otherwise(Symbol("id")).as("key2"), + Symbol("id") as "value2") .createOrReplaceTempView("skewData2") def checkSkewJoin( @@ -913,17 +915,17 @@ class AdaptiveQueryExecSuite spark .range(0, 1000, 1, 10) .select( - when('id < 250, 249) - .when('id >= 750, 1000) - .otherwise('id).as("key1"), - 'id as "value1") + when(Symbol("id") < 250, 249) + 
.when(Symbol("id") >= 750, 1000) + .otherwise(Symbol("id")).as("key1"), + Symbol("id") as "value1") .createOrReplaceTempView("skewData1") spark .range(0, 1000, 1, 10) .select( - when('id < 250, 249) - .otherwise('id).as("key2"), - 'id as "value2") + when(Symbol("id") < 250, 249) + .otherwise(Symbol("id")).as("key2"), + Symbol("id") as "value2") .createOrReplaceTempView("skewData2") val (_, adaptivePlan) = runAdaptiveAndVerifyResult( "SELECT * FROM skewData1 join skewData2 ON key1 = key2") @@ -998,7 +1000,7 @@ class AdaptiveQueryExecSuite test("AQE should set active session during execution") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val df = spark.range(10).select(sum('id)) + val df = spark.range(10).select(sum(Symbol("id"))) assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec]) SparkSession.setActiveSession(null) checkAnswer(df, Seq(Row(45))) @@ -1025,7 +1027,7 @@ class AdaptiveQueryExecSuite SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true") { try { spark.experimental.extraStrategies = TestStrategy :: Nil - val df = spark.range(10).groupBy('id).count() + val df = spark.range(10).groupBy(Symbol("id")).count() df.collect() } finally { spark.experimental.extraStrategies = Nil @@ -1311,7 +1313,7 @@ class AdaptiveQueryExecSuite test("SPARK-33494: Do not use local shuffle reader for repartition") { withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - val df = spark.table("testData").repartition('key) + val df = spark.table("testData").repartition(Symbol("key")) df.collect() // local shuffle reader breaks partitioning and shouldn't be used for repartition operation // which is specified by users. @@ -1341,7 +1343,7 @@ class AdaptiveQueryExecSuite withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "80") { // Repartition with no partition num specified. - val dfRepartition = df.repartition('b) + val dfRepartition = df.repartition(Symbol("b")) dfRepartition.collect() val plan = dfRepartition.queryExecution.executedPlan // The top shuffle from repartition is optimized out. @@ -1355,7 +1357,7 @@ class AdaptiveQueryExecSuite assert(customReader.get.asInstanceOf[CustomShuffleReaderExec].hasCoalescedPartition) // Repartition with partition default num specified. - val dfRepartitionWithNum = df.repartition(5, 'b) + val dfRepartitionWithNum = df.repartition(5, Symbol("b")) dfRepartitionWithNum.collect() val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan // The top shuffle from repartition is optimized out. @@ -1367,7 +1369,7 @@ class AdaptiveQueryExecSuite assert(bhjWithNum.head.right.find(_.isInstanceOf[CustomShuffleReaderExec]).isEmpty) // Repartition with partition non-default num specified. - val dfRepartitionWithNum2 = df.repartition(3, 'b) + val dfRepartitionWithNum2 = df.repartition(3, Symbol("b")) dfRepartitionWithNum2.collect() val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan // The top shuffle from repartition is not optimized out, and this is the only shuffle that @@ -1388,7 +1390,7 @@ class AdaptiveQueryExecSuite SQLConf.SKEW_JOIN_SKEWED_PARTITION_FACTOR.key -> "0", SQLConf.ADVISORY_PARTITION_SIZE_IN_BYTES.key -> "10") { // Repartition with no partition num specified. - val dfRepartition = df.repartition('b) + val dfRepartition = df.repartition(Symbol("b")) dfRepartition.collect() val plan = dfRepartition.queryExecution.executedPlan // The top shuffle from repartition is optimized out. 
@@ -1404,7 +1406,7 @@ class AdaptiveQueryExecSuite assert(customReaders.length == 2) // Repartition with default partition num specified. - val dfRepartitionWithNum = df.repartition(5, 'b) + val dfRepartitionWithNum = df.repartition(5, Symbol("b")) dfRepartitionWithNum.collect() val planWithNum = dfRepartitionWithNum.queryExecution.executedPlan // The top shuffle from repartition is optimized out. @@ -1420,7 +1422,7 @@ class AdaptiveQueryExecSuite assert(customReadersWithNum.isEmpty) // Repartition with default non-partition num specified. - val dfRepartitionWithNum2 = df.repartition(3, 'b) + val dfRepartitionWithNum2 = df.repartition(3, Symbol("b")) dfRepartitionWithNum2.collect() val planWithNum2 = dfRepartitionWithNum2.queryExecution.executedPlan // The top shuffle from repartition is not optimized out. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala index e566f5d5adee..c3676164235c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/RangeBenchmark.scala @@ -48,7 +48,7 @@ object RangeBenchmark extends SqlBasedBenchmark { } benchmark.addCase("filter after range", numIters = 4) { _ => - spark.range(N).filter('id % 100 === 0).noop() + spark.range(N).filter(Symbol("id") % 100 === 0).noop() } benchmark.addCase("count after range", numIters = 4) { _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index b8f73f4563ef..5f518aa53e76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -152,7 +152,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSparkSession { } test("projection") { - val logicalPlan = testData.select('value, 'key).logicalPlan + val logicalPlan = testData.select(Symbol("value"), Symbol("key")).logicalPlan val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan val scan = InMemoryRelation(new TestCachedBatchSerializer(useCompression = true, 5), MEMORY_ONLY, plan, None, logicalPlan) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 96f9421e1d98..918b4932ddac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -284,11 +284,12 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { "func", Seq.empty, plans.table("e"), null) compareTransformQuery("select transform(a, b) using 'func' from e where f < 10", - p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + p.copy(child = p.child.where(Symbol("f") < 10), + output = Seq(Symbol("key").string, Symbol("value").string))) compareTransformQuery("map a, b using 'func' as c, d from e", - p.copy(output = Seq('c.string, 'd.string))) + p.copy(output = Seq(Symbol("c").string, Symbol("d").string))) compareTransformQuery("reduce a, b using 'func' as (c int, d decimal(10, 0)) from e", - p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) 
+ p.copy(output = Seq(Symbol("c").int, Symbol("d").decimal(10, 0)))) } test("use backticks in output of Script Transform") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 45406835453b..a52e34ad2365 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -115,7 +115,8 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { }.getMessage assert(e.contains("Hive support is required to CREATE Hive TABLE (AS SELECT)")) - spark.range(1).select('id as 'a, 'id as 'b).write.saveAsTable("t1") + spark.range(1).select(Symbol("id") as Symbol("a"), Symbol("id") as Symbol("b")) + .write.saveAsTable("t1") e = intercept[AnalysisException] { sql("CREATE TABLE t STORED AS parquet SELECT a, b from t1") }.getMessage @@ -1624,7 +1625,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("CREATE TABLE t USING parquet SELECT 1 as a, 1 as b") checkAnswer(spark.table("t"), Row(1, 1) :: Nil) - spark.range(1).select('id as 'a, 'id as 'b).write.saveAsTable("t1") + spark.range(1).select(Symbol("id") as Symbol("a"), Symbol("id") as Symbol("b")) + .write.saveAsTable("t1") sql("CREATE TABLE t2 USING parquet SELECT a, b from t1") checkAnswer(spark.table("t2"), spark.table("t1")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index b94918eccd46..2085feb31fa3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -26,12 +26,12 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT class DataSourceStrategySuite extends PlanTest with SharedSparkSession { val attrInts = Seq( - 'cint.int, + Symbol("cint").int, Symbol("c.int").int, - GetStructField('a.struct(StructType( + GetStructField(Symbol("a").struct(StructType( StructField("cstr", StringType, nullable = true) :: StructField("cint", IntegerType, nullable = true) :: Nil)), 1, None), - GetStructField('a.struct(StructType( + GetStructField(Symbol("a").struct(StructType( StructField("c.int", IntegerType, nullable = true) :: StructField("cstr", StringType, nullable = true) :: Nil)), 0, None), GetStructField(Symbol("a.b").struct(StructType( @@ -40,7 +40,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { StructField("cint", IntegerType, nullable = true) :: Nil)), 2, None), GetStructField(Symbol("a.b").struct(StructType( StructField("c.int", IntegerType, nullable = true) :: Nil)), 0, None), - GetStructField(GetStructField('a.struct(StructType( + GetStructField(GetStructField(Symbol("a").struct(StructType( StructField("cstr1", StringType, nullable = true) :: StructField("b", StructType(StructField("cint", IntegerType, nullable = true) :: StructField("cstr2", StringType, nullable = true) :: Nil)) :: Nil)), 1, None), 0, None) @@ -55,12 +55,12 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { )) val attrStrs = Seq( - 'cstr.string, + Symbol("cstr").string, Symbol("c.str").string, - GetStructField('a.struct(StructType( + GetStructField(Symbol("a").struct(StructType( StructField("cint", IntegerType, 
nullable = true) :: StructField("cstr", StringType, nullable = true) :: Nil)), 1, None), - GetStructField('a.struct(StructType( + GetStructField(Symbol("a").struct(StructType( StructField("c.str", StringType, nullable = true) :: StructField("cint", IntegerType, nullable = true) :: Nil)), 0, None), GetStructField(Symbol("a.b").struct(StructType( @@ -69,7 +69,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { StructField("cstr", StringType, nullable = true) :: Nil)), 2, None), GetStructField(Symbol("a.b").struct(StructType( StructField("c.str", StringType, nullable = true) :: Nil)), 0, None), - GetStructField(GetStructField('a.struct(StructType( + GetStructField(GetStructField(Symbol("a").struct(StructType( StructField("cint1", IntegerType, nullable = true) :: StructField("b", StructType(StructField("cstr", StringType, nullable = true) :: StructField("cint2", IntegerType, nullable = true) :: Nil)) :: Nil)), 1, None), 0, None) @@ -280,7 +280,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { }} test("SPARK-26865 DataSourceV2Strategy should push normalized filters") { - val attrInt = 'cint.int + val attrInt = Symbol("cint").int assertResult(Seq(IsNotNull(attrInt))) { DataSourceStrategy.normalizeExprs(Seq(IsNotNull(attrInt.withName("CiNt"))), Seq(attrInt)) } @@ -308,7 +308,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { } // `Abs(col)` can not be pushed down, so it returns `None` - assert(PushableColumnAndNestedColumn.unapply(Abs('col.int)) === None) + assert(PushableColumnAndNestedColumn.unapply(Abs(Symbol("col").int)) === None) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala index 6ba3d2723412..3034d4fe67c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceSuite.scala @@ -143,7 +143,8 @@ class DataSourceSuite extends SharedSparkSession with PrivateMethodTester { test("Data source options should be propagated in method checkAndGlobPathIfNecessary") { val dataSourceOptions = Map("fs.defaultFS" -> "nonexistentFs://nonexistentFs") val dataSource = DataSource(spark, "parquet", Seq("/path3"), options = dataSourceOptions) - val checkAndGlobPathIfNecessary = PrivateMethod[Seq[Path]]('checkAndGlobPathIfNecessary) + val checkAndGlobPathIfNecessary = + PrivateMethod[Seq[Path]](Symbol("checkAndGlobPathIfNecessary")) val message = intercept[java.io.IOException] { dataSource invokePrivate checkAndGlobPathIfNecessary(false, false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala index f492fc653653..663b7c9efd99 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala @@ -39,12 +39,14 @@ class FileFormatWriterSuite test("SPARK-22252: FileFormatWriter should respect the input query schema") { withTable("t1", "t2", "t3", "t4") { - spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1") + spark.range(1).select(Symbol("id") as Symbol("col1"), Symbol("id") as Symbol("col2")) + .write.saveAsTable("t1") spark.sql("select COL1, COL2 
from t1").write.saveAsTable("t2") checkAnswer(spark.table("t2"), Row(0, 0)) // Test picking part of the columns when writing. - spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3") + spark.range(1).select(Symbol("id"), Symbol("id") as Symbol("col1"), + Symbol("id") as Symbol("col2")).write.saveAsTable("t3") spark.sql("select COL1, COL2 from t3").write.saveAsTable("t4") checkAnswer(spark.table("t4"), Row(0, 0)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 50f32126e5de..4fb28a14637c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -60,7 +60,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre "file9" -> 1, "file10" -> 1)) - checkScan(table.select('c1)) { partitions => + checkScan(table.select(Symbol("c1"))) { partitions => // 10 one byte files should fit in a single partition with 10 files. assert(partitions.size == 1, "when checking partitions") assert(partitions.head.files.size == 10, "when checking partition 1") @@ -83,7 +83,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "11", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { - checkScan(table.select('c1)) { partitions => + checkScan(table.select(Symbol("c1"))) { partitions => // 5 byte files should be laid out [(5, 5), (5)] assert(partitions.size == 2, "when checking partitions") assert(partitions(0).files.size == 2, "when checking partition 1") @@ -108,7 +108,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { - checkScan(table.select('c1)) { partitions => + checkScan(table.select(Symbol("c1"))) { partitions => // Files should be laid out [(0-10), (10-15, 4)] assert(partitions.size == 2, "when checking partitions") assert(partitions(0).files.size == 1, "when checking partition 1") @@ -141,7 +141,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "4", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { - checkScan(table.select('c1)) { partitions => + checkScan(table.select(Symbol("c1"))) { partitions => // Files should be laid out [(file1), (file2, file3), (file4, file5), (file6)] assert(partitions.size == 4, "when checking partitions") assert(partitions(0).files.size == 1, "when checking partition 1") @@ -358,7 +358,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre withSQLConf( SQLConf.FILES_MAX_PARTITION_BYTES.key -> "2", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "0") { - checkScan(table.select('c1)) { partitions => + checkScan(table.select(Symbol("c1"))) { partitions => assert(partitions.size == 2) assert(partitions(0).files.size == 1) assert(partitions(1).files.size == 2) @@ -374,7 +374,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSparkSession with Pre withSQLConf( SQLConf.FILES_MAX_PARTITION_BYTES.key -> "2", SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "0") { - checkScan(table.select('c1)) { partitions => + checkScan(table.select(Symbol("c1"))) { partitions => assert(partitions.size == 3) 
assert(partitions(0).files.size == 1) assert(partitions(1).files.size == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index c90732183cb7..0dc8eb9c5848 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -532,7 +532,7 @@ abstract class SchemaPruningSuite Seq(Concat(Seq($"name.first", $"name.last")), Concat(Seq($"name.last", $"name.first"))) ), - Seq('a.string, 'b.string), + Seq(Symbol("a").string, Symbol("b").string), sql("select * from contacts").logicalPlan ).toDF() checkScan(query1, "struct>") @@ -549,7 +549,7 @@ abstract class SchemaPruningSuite val name = StructType.fromDDL("first string, middle string, last string") val query2 = Expand( Seq(Seq($"name", $"name.last")), - Seq('a.struct(name), 'b.string), + Seq(Symbol("a").struct(name), Symbol("b").string), sql("select * from contacts").logicalPlan ).toDF() checkScan(query2, "struct>") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 30f0e45d04ea..52f378979f90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1628,7 +1628,7 @@ abstract class CSVSuite val idf = spark.read .schema(schema) .csv(path.getCanonicalPath) - .select('f15, 'f10, 'f5) + .select(Symbol("f15"), Symbol("f10"), Symbol("f5")) assert(idf.count() == 2) checkAnswer(idf, List(Row(15, 10, 5), Row(-15, -10, -5))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala index ffe8e66f3368..aabf65852f94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonBenchmark.scala @@ -263,7 +263,7 @@ object JsonBenchmark extends SqlBasedBenchmark { benchmark.addCase("from_json", iters) { _ => val schema = new StructType().add("a", IntegerType) - val from_json_ds = in.select(from_json('value, schema)) + val from_json_ds = in.select(from_json(Symbol("value"), schema)) from_json_ds.noop() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopStreamSuite.scala index 3cb8287f09b2..b892a9e15581 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopStreamSuite.scala @@ -90,7 +90,7 @@ class NoopStreamSuite extends StreamTest { .option("numPartitions", "1") .option("rowsPerSecond", "5") .load() - .select('value) + .select(Symbol("value")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopSuite.scala index b4073bedf559..811953754953 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/noop/NoopSuite.scala @@ -42,7 +42,7 @@ class NoopSuite extends SharedSparkSession { withTempPath { dir => val path = dir.getCanonicalPath spark.range(numElems) - .select('id mod 10 as "key", 'id as "value") + .select(Symbol("id") mod 10 as "key", Symbol("id") as "value") .write .partitionBy("key") .parquet(path) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index ead2c2cf1b70..17669552f640 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -368,7 +368,7 @@ abstract class OrcQueryTest extends OrcTest { withTempPath { dir => val path = dir.getCanonicalPath - spark.range(0, 10).select('id as "Acol").write.orc(path) + spark.range(0, 10).select(Symbol("id") as "Acol").write.orc(path) spark.read.orc(path).schema("Acol") intercept[IllegalArgumentException] { spark.read.orc(path).schema("acol") @@ -413,19 +413,19 @@ abstract class OrcQueryTest extends OrcTest { s"No data was filtered for predicate: $pred") } - checkPredicate('a === 5, List(5).map(Row(_, null))) - checkPredicate('a <=> 5, List(5).map(Row(_, null))) - checkPredicate('a < 5, List(1, 3).map(Row(_, null))) - checkPredicate('a <= 5, List(1, 3, 5).map(Row(_, null))) - checkPredicate('a > 5, List(7, 9).map(Row(_, null))) - checkPredicate('a >= 5, List(5, 7, 9).map(Row(_, null))) - checkPredicate('a.isNull, List(null).map(Row(_, null))) - checkPredicate('b.isNotNull, List()) - checkPredicate('a.isin(3, 5, 7), List(3, 5, 7).map(Row(_, null))) - checkPredicate('a > 0 && 'a < 3, List(1).map(Row(_, null))) - checkPredicate('a < 1 || 'a > 8, List(9).map(Row(_, null))) - checkPredicate(!('a > 3), List(1, 3).map(Row(_, null))) - checkPredicate(!('a > 0 && 'a < 3), List(3, 5, 7, 9).map(Row(_, null))) + checkPredicate(Symbol("a") === 5, List(5).map(Row(_, null))) + checkPredicate(Symbol("a") <=> 5, List(5).map(Row(_, null))) + checkPredicate(Symbol("a") < 5, List(1, 3).map(Row(_, null))) + checkPredicate(Symbol("a") <= 5, List(1, 3, 5).map(Row(_, null))) + checkPredicate(Symbol("a") > 5, List(7, 9).map(Row(_, null))) + checkPredicate(Symbol("a") >= 5, List(5, 7, 9).map(Row(_, null))) + checkPredicate(Symbol("a").isNull, List(null).map(Row(_, null))) + checkPredicate(Symbol("b").isNotNull, List()) + checkPredicate(Symbol("a").isin(3, 5, 7), List(3, 5, 7).map(Row(_, null))) + checkPredicate(Symbol("a") > 0 && Symbol("a") < 3, List(1).map(Row(_, null))) + checkPredicate(Symbol("a") < 1 || Symbol("a") > 8, List(9).map(Row(_, null))) + checkPredicate(!(Symbol("a") > 3), List(1, 3).map(Row(_, null))) + checkPredicate(!(Symbol("a") > 0 && Symbol("a") < 3), List(3, 5, 7, 9).map(Row(_, null))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala index d7727d93ddf9..916a21ff1004 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala @@ -268,7 +268,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest 
with Shared } } - checkAnswer(spark.read.parquet(path).filter('suit === "SPADES"), Row("SPADES")) + checkAnswer(spark.read.parquet(path).filter(Symbol("suit") === "SPADES"), Row("SPADES")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index a546538adc56..1b7cc706c820 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -1327,39 +1327,39 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared test("filter pushdown - StringStartsWith") { withParquetDataFrame((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit df => checkFilterPredicate( - '_1.startsWith("").asInstanceOf[Predicate], + Symbol("_1").startsWith("").asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], Seq("1str1", "2str2", "3str3", "4str4").map(Row(_))) Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix => checkFilterPredicate( - '_1.startsWith(prefix).asInstanceOf[Predicate], + Symbol("_1").startsWith(prefix).asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], "2str2") } Seq("2S", "null", "2str22").foreach { prefix => checkFilterPredicate( - '_1.startsWith(prefix).asInstanceOf[Predicate], + Symbol("_1").startsWith(prefix).asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], Seq.empty[Row]) } checkFilterPredicate( - !'_1.startsWith("").asInstanceOf[Predicate], + !Symbol("_1").startsWith("").asInstanceOf[Predicate], classOf[Operators.Not], Seq().map(Row(_))) Seq("2", "2s", "2st", "2str", "2str2").foreach { prefix => checkFilterPredicate( - !'_1.startsWith(prefix).asInstanceOf[Predicate], + !Symbol("_1").startsWith(prefix).asInstanceOf[Predicate], classOf[Operators.Not], Seq("1str1", "3str3", "4str4").map(Row(_))) } Seq("2S", "null", "2str22").foreach { prefix => checkFilterPredicate( - !'_1.startsWith(prefix).asInstanceOf[Predicate], + !Symbol("_1").startsWith(prefix).asInstanceOf[Predicate], classOf[Operators.Not], Seq("1str1", "2str2", "3str3", "4str4").map(Row(_))) } @@ -1373,7 +1373,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared // SPARK-28371: make sure filter is null-safe. withParquetDataFrame(Seq(Tuple1[String](null))) { implicit df => checkFilterPredicate( - '_1.startsWith("blah").asInstanceOf[Predicate], + Symbol("_1").startsWith("blah").asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], Seq.empty[Row]) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 4f334f85ebf8..1c19ee2f9060 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -144,7 +144,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession .range(1000) // Parquet doesn't allow column names with spaces, have to add an alias here. // Minus 500 here so that negative decimals are also tested. 
- .select((('id - 500) / 100.0) cast decimal as 'dec) + .select(((Symbol("id") - 500) / 100.0) cast decimal as Symbol("dec")) .coalesce(1) } @@ -580,7 +580,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession withTempPath { dir => val m2 = intercept[SparkException] { - val df = spark.range(1).select('id as 'a, 'id as 'b).coalesce(1) + val df = spark.range(1).select(Symbol("id") as Symbol("a"), + Symbol("id") as Symbol("b")).coalesce(1) df.write.partitionBy("a").options(extraOptions).parquet(dir.getCanonicalPath) }.getCause.getMessage assert(m2.contains("Intentional exception for testing purposes")) @@ -646,7 +647,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession checkAnswer( // Decimal column in this file is encoded using plain dictionary readResourceParquetFile("test-data/dec-in-i32.parquet"), - spark.range(1 << 4).select('id % 10 cast DecimalType(5, 2) as 'i32_dec)) + spark.range(1 << 4).select(Symbol("id") % 10 cast DecimalType(5, 2) as Symbol("i32_dec"))) } } @@ -655,7 +656,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession checkAnswer( // Decimal column in this file is encoded using plain dictionary readResourceParquetFile("test-data/dec-in-i64.parquet"), - spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'i64_dec)) + spark.range(1 << 4).select(Symbol("id") % 10 cast DecimalType(10, 2) as Symbol("i64_dec"))) } } @@ -664,7 +665,8 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession checkAnswer( // Decimal column in this file is encoded using plain dictionary readResourceParquetFile("test-data/dec-in-fixed-len.parquet"), - spark.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) + spark.range(1 << 4) + .select(Symbol("id") % 10 cast DecimalType(10, 2) as Symbol("fixed_len_dec"))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 400f4d8e1b15..1ca39dc785d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -940,7 +940,8 @@ abstract class ParquetPartitionDiscoverySuite withTempPath { dir => withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { val path = dir.getCanonicalPath - val df = spark.range(5).select('id as 'a, 'id as 'b, 'id as 'c).coalesce(1) + val df = spark.range(5).select(Symbol("id") as Symbol("a"), + Symbol("id") as Symbol("b"), Symbol("id") as Symbol("c")).coalesce(1) df.write.partitionBy("b", "c").parquet(path) checkAnswer(spark.read.parquet(path), df) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 8f85fe3c5258..7ba8e3dc4314 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -151,7 +151,7 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS (1, "2016-01-01 10:11:12.123456"), (2, null), (3, "1965-01-01 10:11:12.123456")) - 
.toDS().select('_1, $"_2".cast("timestamp")) + .toDS().select(Symbol("_1"), $"_2".cast("timestamp")) checkAnswer(sql("select * from ts"), expected) } } @@ -728,7 +728,7 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS test("SPARK-15804: write out the metadata to parquet file") { val df = Seq((1, "abc"), (2, "hello")).toDF("a", "b") val md = new MetadataBuilder().putString("key", "value").build() - val dfWithmeta = df.select('a, 'b.as("b", md)) + val dfWithmeta = df.select(Symbol("a"), Symbol("b").as("b", md)) withTempPath { dir => val path = dir.getCanonicalPath @@ -854,7 +854,7 @@ class ParquetV1QuerySuite extends ParquetQuerySuite { withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "10") { withTempPath { dir => val path = dir.getCanonicalPath - val df = spark.range(10).select(Seq.tabulate(11) {i => ('id + i).as(s"c$i")} : _*) + val df = spark.range(10).select(Seq.tabulate(11) {i => (Symbol("id") + i).as(s"c$i")} : _*) df.write.mode(SaveMode.Overwrite).parquet(path) // do not return batch - whole stage codegen is disabled for wide table (>200 columns) @@ -887,7 +887,7 @@ class ParquetV2QuerySuite extends ParquetQuerySuite { withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "10") { withTempPath { dir => val path = dir.getCanonicalPath - val df = spark.range(10).select(Seq.tabulate(11) {i => ('id + i).as(s"c$i")} : _*) + val df = spark.range(10).select(Seq.tabulate(11) {i => (Symbol("id") + i).as(s"c$i")} : _*) df.write.mode(SaveMode.Overwrite).parquet(path) // do not return batch - whole stage codegen is disabled for wide table (>200 columns) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index fcc08ee16e80..6f7880a68b7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -394,7 +394,8 @@ class ParquetSchemaSuite extends ParquetSchemaTest { withTempPath { dir => val path = dir.getCanonicalPath spark.range(3).write.parquet(s"$path/p=1") - spark.range(3).select('id cast IntegerType as 'id).write.parquet(s"$path/p=2") + spark.range(3).select(Symbol("id") cast IntegerType as Symbol("id")) + .write.parquet(s"$path/p=2") val message = intercept[SparkException] { spark.read.option("mergeSchema", "true").parquet(path).schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 98a1089709b9..8d583e106477 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -414,8 +414,8 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils test("Broadcast timeout") { val timeout = 5 val slowUDF = udf({ x: Int => Thread.sleep(timeout * 10 * 1000); x }) - val df1 = spark.range(10).select($"id" as 'a) - val df2 = spark.range(5).select(slowUDF($"id") as 'a) + val df1 = spark.range(10).select($"id" as Symbol("a")) + val df2 = spark.range(5).select(slowUDF($"id") as Symbol("a")) val testDf = df1.join(broadcast(df2), "a") withSQLConf(SQLConf.BROADCAST_TIMEOUT.key -> timeout.toString) { val e = intercept[Exception] { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 2b703c06fa90..f4781d43c1ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -74,7 +74,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // Assume the execution plan is // PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0) Seq((0L, false), (1L, true)).foreach { case (nodeId, enableWholeStage) => - val df = person.filter('age < 25) + val df = person.filter(Symbol("age") < 25) testSparkPlanMetrics(df, 1, Map( nodeId -> (("Filter", Map( "number of output rows" -> 1L)))), @@ -89,7 +89,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // Filter(nodeId = 1) // Range(nodeId = 2) // TODO: update metrics in generated operators - val ds = spark.range(10).filter('id < 5) + val ds = spark.range(10).filter(Symbol("id") < 5) testSparkPlanMetricsWithPredicates(ds.toDF(), 1, Map( 0L -> (("WholeStageCodegen (1)", Map( "duration" -> { @@ -121,7 +121,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils ) // 2 partitions and each partition contains 2 keys - val df2 = testData2.groupBy('a).count() + val df2 = testData2.groupBy(Symbol("a")).count() val expected2 = Seq( Map("number of output rows" -> 4L, "avg hash probe bucket list iters" -> @@ -167,7 +167,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // Exchange(nodeId = 5) // LocalTableScan(nodeId = 6) Seq(true, false).foreach { enableWholeStage => - val df = generateRandomBytesDF().repartition(1).groupBy('a).count() + val df = generateRandomBytesDF().repartition(1).groupBy(Symbol("a")).count() val nodeIds = if (enableWholeStage) { Set(4L, 1L) } else { @@ -195,7 +195,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // Assume the execution plan is // ... 
-> ObjectHashAggregate(nodeId = 2) -> Exchange(nodeId = 1) // -> ObjectHashAggregate(nodeId = 0) - val df = testData2.groupBy().agg(collect_set('a)) // 2 partitions + val df = testData2.groupBy().agg(collect_set(Symbol("a"))) // 2 partitions testSparkPlanMetrics(df, 1, Map( 2L -> (("ObjectHashAggregate", Map("number of output rows" -> 2L))), 1L -> (("Exchange", Map( @@ -207,7 +207,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils ) // 2 partitions and each partition contains 2 keys - val df2 = testData2.groupBy('a).agg(collect_set('a)) + val df2 = testData2.groupBy(Symbol("a")).agg(collect_set(Symbol("a"))) testSparkPlanMetrics(df2, 1, Map( 2L -> (("ObjectHashAggregate", Map( "number of output rows" -> 4L, @@ -224,7 +224,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // 2 partitions and each partition contains 2 keys, with fallback to sort-based aggregation withSQLConf(SQLConf.OBJECT_AGG_SORT_BASED_FALLBACK_THRESHOLD.key -> "1") { - val df3 = testData2.groupBy('a).agg(collect_set('a)) + val df3 = testData2.groupBy(Symbol("a")).agg(collect_set(Symbol("a"))) testSparkPlanMetrics(df3, 1, Map( 2L -> (("ObjectHashAggregate", Map( "number of output rows" -> 4L, @@ -254,7 +254,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils // LocalTableScan(nodeId = 3) // Because of SPARK-25267, ConvertToLocalRelation is disabled in the test cases of sql/core, // so Project here is not collapsed into LocalTableScan. - val df = Seq(1, 3, 2).toDF("id").sort('id) + val df = Seq(1, 3, 2).toDF("id").sort(Symbol("id")) testSparkPlanMetricsWithPredicates(df, 2, Map( 0L -> (("Sort", Map( "sort time" -> { @@ -272,7 +272,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils test("SortMergeJoin metrics") { // Because SortMergeJoin may skip different rows if the number of partitions is different, this // test should use the deterministic number of partitions. - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter(Symbol("a") < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.createOrReplaceTempView("testDataForJoin") withTempView("testDataForJoin") { // Assume the execution plan is @@ -298,7 +298,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils test("SortMergeJoin(outer) metrics") { // Because SortMergeJoin may skip different rows if the number of partitions is different, // this test should use the deterministic number of partitions. 
- val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter(Symbol("a") < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.createOrReplaceTempView("testDataForJoin") withTempView("testDataForJoin") { // Assume the execution plan is @@ -441,7 +441,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils } test("BroadcastNestedLoopJoin metrics") { - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter(Symbol("a") < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.createOrReplaceTempView("testDataForJoin") withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { withTempView("testDataForJoin") { @@ -494,7 +494,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils test("CartesianProduct metrics") { withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2) + val testDataForJoin = testData2.filter(Symbol("a") < 2) // TestData2(1, 1) :: TestData2(1, 2) testDataForJoin.createOrReplaceTempView("testDataForJoin") withTempView("testDataForJoin") { // Assume the execution plan is @@ -529,7 +529,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils test("save metrics") { withTempPath { file => // person creates a temporary view. get the DF before listing previous execution IDs - val data = person.select('name) + val data = person.select(Symbol("name")) val previousExecutionIds = currentExecutionIds() // Assume the execution plan is // PhysicalRDD(nodeId = 0) @@ -663,7 +663,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils } withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { - val df = spark.range(0, 3000, 1, 2).toDF().filter('id % 3 === 0) + val df = spark.range(0, 3000, 1, 2).toDF().filter(Symbol("id") % 3 === 0) df.collect() checkFilterAndRangeMetrics(df, filterNumOutputs = 1000, rangeNumOutputs = 3000) @@ -686,7 +686,8 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { // A special query that only has one partition, so there is no shuffle and the entire query // can be whole-stage-codegened. 
- val df = spark.range(0, 1500, 1, 1).limit(10).groupBy('id).count().limit(1).filter('id >= 0) + val df = spark.range(0, 1500, 1, 1).limit(10).groupBy(Symbol("id")) + .count().limit(1).filter(Symbol("id") >= 0) df.collect() val plan = df.queryExecution.executedPlan @@ -749,7 +750,7 @@ class SQLMetricsSuite extends SharedSparkSession with SQLMetricsTestUtils } test("SPARK-28332: SQLMetric merge should handle -1 properly") { - val df = testData.join(testData2.filter('b === 0), $"key" === $"a", "left_outer") + val df = testData.join(testData2.filter(Symbol("b") === 0), $"key" === $"a", "left_outer") df.collect() val plan = df.queryExecution.executedPlan diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala index a508f923ffa1..217c13cc3e95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala @@ -40,8 +40,8 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { val df = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(df)( @@ -79,8 +79,8 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter { val df = testSource.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long]) /** Reset this test source so that it appears to be a new source requiring initialization */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala index 5884380271f0..11dbf9c2beaa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala @@ -141,7 +141,7 @@ class ConsoleWriteSupportSuite extends StreamTest { .option("numPartitions", "1") .option("rowsPerSecond", "5") .load() - .select('value) + .select(Symbol("value")) val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start() assert(query.isActive) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala index 0fe339b93047..46440c98226a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterSuite.scala @@ -165,8 +165,8 @@ class ForeachWriterSuite extends StreamTest with SharedSparkSession with BeforeA val windowedAggregation = inputData.toDF() 
.withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"count".as[Long]) .map(_.toInt) .repartition(1) @@ -199,8 +199,8 @@ class ForeachWriterSuite extends StreamTest with SharedSparkSession with BeforeA val windowedAggregation = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"count".as[Long]) .map(_.toInt) .repartition(1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 6440e69e2ec2..2c1bb41302c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -83,7 +83,7 @@ class RateStreamProviderSuite extends StreamTest { .format("rate") .option("rowsPerSecond", "10") .load() - .select('value) + .select(Symbol("value")) var streamDuration = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala index 13b22dba1168..e57d2a518532 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ExecutorSideSQLConfSuite.scala @@ -81,7 +81,7 @@ class ExecutorSideSQLConfSuite extends SparkFunSuite with SQLTestUtils { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { withTempPath { path => val pathString = path.getCanonicalPath - spark.range(10).select('id.as("ID")).write.json(pathString) + spark.range(10).select(Symbol("id").as("ID")).write.json(pathString) spark.range(10).write.mode("append").json(pathString) assert(spark.read.json(pathString).columns.toSet == Set("id", "ID")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index cc2721149cff..f65a6930131f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -367,7 +367,8 @@ class JDBCSuite extends QueryTest } test("SELECT * WHERE (quoted strings)") { - assert(sql("select * from foobar").where('NAME === "joe 'foo' \"bar\"").collect().size === 1) + assert(sql("select * from foobar") + .where(Symbol("NAME") === "joe 'foo' \"bar\"").collect().size === 1) } test("SELECT first field") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala index 81ce979ef0b6..59239397c6ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -36,7 +36,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with override def 
beforeAll(): Unit = { super.beforeAll() - targetAttributes = Seq('a.int, 'd.int, 'b.int, 'c.int) + targetAttributes = Seq(Symbol("a").int, Symbol("d").int, Symbol("b").int, Symbol("c").int) targetPartitionSchema = new StructType() .add("b", IntegerType) .add("c", IntegerType) @@ -74,7 +74,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with caseSensitive) { intercept[AssertionError] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int, 'f.int), + sourceAttributes = Seq(Symbol("e").int, Symbol("f").int), providedPartitions = Map("b" -> None, "c" -> None), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) @@ -85,7 +85,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with // Missing columns. intercept[AnalysisException] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int), + sourceAttributes = Seq(Symbol("e").int), providedPartitions = Map("b" -> Some("1"), "c" -> None), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) @@ -96,7 +96,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with // Missing partitioning columns. intercept[AnalysisException] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int, 'f.int), + sourceAttributes = Seq(Symbol("e").int, Symbol("f").int), providedPartitions = Map("b" -> Some("1")), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) @@ -105,7 +105,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with // Missing partitioning columns. intercept[AnalysisException] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int, 'f.int, 'g.int), + sourceAttributes = Seq(Symbol("e").int, Symbol("f").int, Symbol("g").int), providedPartitions = Map("b" -> Some("1")), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) @@ -114,7 +114,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with // Wrong partitioning columns. intercept[AnalysisException] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int, 'f.int), + sourceAttributes = Seq(Symbol("e").int, Symbol("f").int), providedPartitions = Map("b" -> Some("1"), "d" -> None), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) @@ -125,7 +125,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with // Wrong partitioning columns. intercept[AnalysisException] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int, 'f.int), + sourceAttributes = Seq(Symbol("e").int, Symbol("f").int), providedPartitions = Map("b" -> Some("1"), "d" -> Some("2")), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) @@ -134,7 +134,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with // Wrong partitioning columns. intercept[AnalysisException] { rule.convertStaticPartitions( - sourceAttributes = Seq('e.int), + sourceAttributes = Seq(Symbol("e").int), providedPartitions = Map("b" -> Some("1"), "c" -> Some("3"), "d" -> Some("2")), targetAttributes = targetAttributes, targetPartitionSchema = targetPartitionSchema) @@ -144,7 +144,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with // Wrong partitioning columns. 
     intercept[AnalysisException] {
       rule.convertStaticPartitions(
-        sourceAttributes = Seq('e.int, 'f.int),
+        sourceAttributes = Seq(Symbol("e").int, Symbol("f").int),
        providedPartitions = Map("b" -> Some("1"), "C" -> Some("3")),
        targetAttributes = targetAttributes,
        targetPartitionSchema = targetPartitionSchema)
@@ -156,7 +156,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with
     // Static partitions need to appear before dynamic partitions.
     intercept[AnalysisException] {
       rule.convertStaticPartitions(
-        sourceAttributes = Seq('e.int, 'f.int),
+        sourceAttributes = Seq(Symbol("e").int, Symbol("f").int),
        providedPartitions = Map("b" -> None, "c" -> Some("3")),
        targetAttributes = targetAttributes,
        targetPartitionSchema = targetPartitionSchema)
@@ -165,7 +165,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with
   testRule("All static partitions", caseSensitive) {
     if (!caseSensitive) {
-      val nonPartitionedAttributes = Seq('e.int, 'f.int)
+      val nonPartitionedAttributes = Seq(Symbol("e").int, Symbol("f").int)
       val expected = nonPartitionedAttributes ++
         Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType))
       val actual = rule.convertStaticPartitions(
@@ -177,7 +177,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with
     }
     {
-      val nonPartitionedAttributes = Seq('e.int, 'f.int)
+      val nonPartitionedAttributes = Seq(Symbol("e").int, Symbol("f").int)
       val expected = nonPartitionedAttributes ++
         Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType))
       val actual = rule.convertStaticPartitions(
@@ -190,20 +190,20 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll with
     // Test the case having a single static partition column.
     {
-      val nonPartitionedAttributes = Seq('e.int, 'f.int)
+      val nonPartitionedAttributes = Seq(Symbol("e").int, Symbol("f").int)
       val expected = nonPartitionedAttributes ++ Seq(cast(Literal("1"), IntegerType))
       val actual = rule.convertStaticPartitions(
         sourceAttributes = nonPartitionedAttributes,
         providedPartitions = Map("b" -> Some("1")),
-        targetAttributes = Seq('a.int, 'd.int, 'b.int),
+        targetAttributes = Seq(Symbol("a").int, Symbol("d").int, Symbol("b").int),
         targetPartitionSchema = new StructType().add("b", IntegerType))
       checkProjectList(actual, expected)
     }
   }
   testRule("Static partition and dynamic partition", caseSensitive) {
-    val nonPartitionedAttributes = Seq('e.int, 'f.int)
-    val dynamicPartitionAttributes = Seq('g.int)
+    val nonPartitionedAttributes = Seq(Symbol("e").int, Symbol("f").int)
+    val dynamicPartitionAttributes = Seq(Symbol("g").int)
     val expected =
       nonPartitionedAttributes ++
         Seq(cast(Literal("1"), IntegerType)) ++
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
index 67ab72a79145..e0b808f67514 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala
@@ -132,8 +132,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche
     val inputData1 = MemoryStream[Int]
     val aggWithoutWatermark = inputData1.toDF()
       .withColumn("eventTime", timestamp_seconds($"value"))
-      .groupBy(window($"eventTime", "5 seconds") as 'window)
-      .agg(count("*") as 'count)
+      .groupBy(window($"eventTime", "5 seconds") as Symbol("window"))
+      .agg(count("*") as Symbol("count"))
.select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(aggWithoutWatermark, outputMode = Complete)( @@ -150,8 +150,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val aggWithWatermark = inputData2.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(aggWithWatermark)( @@ -173,8 +173,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val aggWithWatermark = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) // Unlike the ProcessingTime trigger, Trigger.Once only runs one trigger every time @@ -228,8 +228,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val aggWithWatermark = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) @@ -290,8 +290,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val windowedAggregation = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(windowedAggregation)( @@ -315,8 +315,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val windowedAggregation = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(windowedAggregation, OutputMode.Update)( @@ -345,8 +345,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val aggWithWatermark = input.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "2 years 5 months") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) def monthsSinceEpoch(date: Date): Int = { @@ -377,8 +377,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val df = inputData.toDF() 
.withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) testStream(df)( @@ -412,17 +412,17 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val firstDf = first.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .select('value) + .select(Symbol("value")) val second = MemoryStream[Int] val secondDf = second.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "5 seconds") - .select('value) + .select(Symbol("value")) withTempDir { checkpointDir => - val unionWriter = firstDf.union(secondDf).agg(sum('value)) + val unionWriter = firstDf.union(secondDf).agg(sum(Symbol("value"))) .writeStream .option("checkpointLocation", checkpointDir.getCanonicalPath) .format("memory") @@ -489,8 +489,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val windowedAggregation = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .groupBy(window($"eventTime", "5 seconds") as 'window) - .agg(count("*") as 'count) + .groupBy(window($"eventTime", "5 seconds") as Symbol("window")) + .agg(count("*") as Symbol("count")) .select($"window".getField("start").cast("long").as[Long], $"count".as[Long]) // No eviction when asked to compute complete results. @@ -515,7 +515,7 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") .groupBy($"eventTime") - .agg(count("*") as 'count) + .agg(count("*") as Symbol("count")) .select($"eventTime".cast("long").as[Long], $"count".as[Long]) testStream(windowedAggregation)( @@ -586,7 +586,7 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val groupEvents = input .withWatermark("eventTime", "2 seconds") .groupBy("symbol", "eventTime") - .agg(count("price") as 'count) + .agg(count("price") as Symbol("count")) .select("symbol", "eventTime", "count") val q = groupEvents.writeStream .outputMode("append") @@ -605,14 +605,14 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche val aliasWindow = inputData.toDF() .withColumn("eventTime", timestamp_seconds($"value")) .withWatermark("eventTime", "10 seconds") - .select(window($"eventTime", "5 seconds") as 'aliasWindow) + .select(window($"eventTime", "5 seconds") as Symbol("aliasWindow")) // Check the eventTime metadata is kept in the top level alias. 
     assert(aliasWindow.logicalPlan.output.exists(
       _.metadata.contains(EventTimeWatermark.delayKey)))
     val windowedAggregation = aliasWindow
-      .groupBy('aliasWindow)
-      .agg(count("*") as 'count)
+      .groupBy(Symbol("aliasWindow"))
+      .agg(count("*") as Symbol("count"))
       .select($"aliasWindow".getField("start").cast("long").as[Long], $"count".as[Long])
     testStream(windowedAggregation)(
@@ -635,8 +635,8 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Matche
     val windowedAggregation = inputData.toDF()
       .withColumn("eventTime", timestamp_seconds($"value"))
       .withWatermark("eventTime", "10 seconds")
-      .groupBy(window($"eventTime", "5 seconds") as 'window)
-      .agg(count("*") as 'count)
+      .groupBy(window($"eventTime", "5 seconds") as Symbol("window"))
+      .agg(count("*") as Symbol("count"))
       .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
     testStream(windowedAggregation)(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index c4e43d24b0b8..516ffa44a996 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -216,7 +216,7 @@ class StreamSuite extends StreamTest {
       query.processAllAvailable()
       // Parquet write page-level CRC checksums will change the file size and
       // affect the data order when reading these files. Please see PARQUET-1746 for details.
-      val outputDf = spark.read.parquet(outputDir.getAbsolutePath).sort('a).as[Long]
+      val outputDf = spark.read.parquet(outputDir.getAbsolutePath).sort(Symbol("a")).as[Long]
       checkDataset[Long](outputDf, (0L to 10L).toArray: _*)
     } finally {
       query.stop()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 491b0d8b2c26..ced89c4587a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -109,7 +109,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
     val aggregated = inputData.toDF()
-      .select($"*", explode($"_2") as 'value)
+      .select($"*", explode($"_2") as Symbol("value"))
       .groupBy($"_1")
       .agg(size(collect_set($"value")))
       .as[(Int, Int)]
@@ -190,8 +190,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
     val aggWithWatermark = inputData.toDF()
       .withColumn("eventTime", timestamp_seconds($"value"))
       .withWatermark("eventTime", "10 seconds")
-      .groupBy(window($"eventTime", "5 seconds") as 'window)
-      .agg(count("*") as 'count)
+      .groupBy(window($"eventTime", "5 seconds") as Symbol("window"))
+      .agg(count("*") as Symbol("count"))
       .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
     implicit class RichStreamExecution(query: StreamExecution) {
@@ -414,7 +414,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
       inputData.toDF()
        .groupBy($"value")
        .agg(count("*"))
-        .where('value >= current_timestamp().cast("long") - 10L)
+        .where(Symbol("value") >= current_timestamp().cast("long") - 10L)
     testStream(aggregated, Complete)(
       StartStream(Trigger.ProcessingTime("10 seconds"), triggerClock = clock),
@@ -465,7 +465,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
     val inputData = MemoryStream[Long]
     val aggregated = inputData.toDF()
-      .select(to_utc_timestamp(from_unixtime('value * SECONDS_PER_DAY), tz))
+      .select(to_utc_timestamp(from_unixtime(Symbol("value") * SECONDS_PER_DAY), tz))
       .toDF("value")
       .groupBy($"value")
       .agg(count("*"))
@@ -512,12 +512,12 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
     val streamInput = MemoryStream[Int]
     val batchDF = Seq(1, 2, 3, 4, 5)
       .toDF("value")
-      .withColumn("parity", 'value % 2)
-      .groupBy('parity)
-      .agg(count("*") as 'joinValue)
+      .withColumn("parity", Symbol("value") % 2)
+      .groupBy(Symbol("parity"))
+      .agg(count("*") as Symbol("joinValue"))
     val joinDF = streamInput
       .toDF()
-      .join(batchDF, 'value === 'parity)
+      .join(batchDF, Symbol("value") === Symbol("parity"))
     // make sure we're planning an aggregate in the first place
     assert(batchDF.queryExecution.optimizedPlan match { case _: Aggregate => true })
@@ -629,7 +629,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
     def createDf(partitions: Int): Dataset[(Long, Long)] = {
       spark.readStream
         .format((new MockSourceProvider).getClass.getCanonicalName)
-        .load().coalesce(partitions).groupBy('a % 1).count().as[(Long, Long)]
+        .load().coalesce(partitions).groupBy(Symbol("a") % 1).count().as[(Long, Long)]
     }
     testStream(createDf(1), Complete())(
@@ -667,7 +667,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with Assertions {
   testWithAllStateVersions("SPARK-22230: last should change with new batches") {
     val input = MemoryStream[Int]
-    val aggregated = input.toDF().agg(last('value))
+    val aggregated = input.toDF().agg(last(Symbol("value")))
     testStream(aggregated, OutputMode.Complete())(
       AddData(input, 1, 2, 3),
       CheckLastBatch(3),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala
index ac9cd1a12d06..7a7c311e56c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingDeduplicationSuite.scala
@@ -114,8 +114,8 @@ class StreamingDeduplicationSuite extends StateStoreMetricsTest {
       .withWatermark("eventTime", "10 seconds")
       .dropDuplicates()
       .withWatermark("eventTime", "10 seconds")
-      .groupBy(window($"eventTime", "5 seconds") as 'window)
-      .agg(count("*") as 'count)
+      .groupBy(window($"eventTime", "5 seconds") as Symbol("window"))
+      .agg(count("*") as Symbol("count"))
       .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])
     testStream(windowedaggregate)(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index 40131e822c5c..d926d28f0c49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -51,9 +51,9 @@ abstract class StreamingJoinSuite
     val input = MemoryStream[Int]
     val df = input.toDF
       .select(
-        'value as "key",
+        Symbol("value") as "key",
         timestamp_seconds($"value") as s"${prefix}Time",
-        ('value * multiplier) as s"${prefix}Value")
+        (Symbol("value") * multiplier) as s"${prefix}Value")
       .withWatermark(s"${prefix}Time", "10 seconds")
     (input, df)
@@ -64,13 +64,16 @@ abstract class StreamingJoinSuite
     val (input1, df1) = setupStream("left", 2)
     val (input2, df2) = setupStream("right", 3)
-    val windowed1 = df1.select('key, window('leftTime, "10 second"), 'leftValue)
-    val windowed2 = df2.select('key, window('rightTime, "10 second"), 'rightValue)
+    val windowed1 = df1.select(Symbol("key"),
+      window(Symbol("leftTime"), "10 second"), Symbol("leftValue"))
+    val windowed2 = df2.select(Symbol("key"),
+      window(Symbol("rightTime"), "10 second"), Symbol("rightValue"))
     val joined = windowed1.join(windowed2, Seq("key", "window"), joinType)
     val select = if (joinType == "left_semi") {
-      joined.select('key, $"window.end".cast("long"), 'leftValue)
+      joined.select(Symbol("key"), $"window.end".cast("long"), Symbol("leftValue"))
     } else {
-      joined.select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+      joined.select(Symbol("key"), $"window.end".cast("long"),
+        Symbol("leftValue"), Symbol("rightValue"))
     }
     (input1, input2, select)
@@ -82,25 +85,29 @@ abstract class StreamingJoinSuite
     val (leftInput, df1) = setupStream("left", 2)
     val (rightInput, df2) = setupStream("right", 3)
     // Use different schemas to ensure the null row is being generated from the correct side.
-    val left = df1.select('key, window('leftTime, "10 second"), 'leftValue)
-    val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string"))
+    val left = df1.select(Symbol("key"),
+      window(Symbol("leftTime"), "10 second"), Symbol("leftValue"))
+    val right = df2.select(Symbol("key"),
+      window(Symbol("rightTime"), "10 second"), Symbol("rightValue").cast("string"))
     val joined = left.join(
         right,
         left("key") === right("key")
           && left("window") === right("window")
-          && 'leftValue > 4,
+          && Symbol("leftValue") > 4,
         joinType)
     val select = if (joinType == "left_semi") {
-      joined.select(left("key"), left("window.end").cast("long"), 'leftValue)
+      joined.select(left("key"), left("window.end").cast("long"), Symbol("leftValue"))
     } else if (joinType == "left_outer") {
-      joined.select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue)
+      joined.select(left("key"), left("window.end").cast("long"),
+        Symbol("leftValue"), Symbol("rightValue"))
     } else if (joinType == "right_outer") {
-      joined.select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue)
+      joined.select(right("key"), right("window.end").cast("long"),
+        Symbol("leftValue"), Symbol("rightValue"))
     } else {
-      joined.select(left("key"), left("window.end").cast("long"), 'leftValue,
-        right("key"), right("window.end").cast("long"), 'rightValue)
+      joined.select(left("key"), left("window.end").cast("long"), Symbol("leftValue"),
+        right("key"), right("window.end").cast("long"), Symbol("rightValue"))
     }
     (leftInput, rightInput, select)
@@ -112,25 +119,29 @@ abstract class StreamingJoinSuite
     val (leftInput, df1) = setupStream("left", 2)
     val (rightInput, df2) = setupStream("right", 3)
     // Use different schemas to ensure the null row is being generated from the correct side.
-    val left = df1.select('key, window('leftTime, "10 second"), 'leftValue)
-    val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string"))
+    val left =
+      df1.select(Symbol("key"), window(Symbol("leftTime"), "10 second"), Symbol("leftValue"))
+    val right = df2.select(Symbol("key"),
+      window(Symbol("rightTime"), "10 second"), Symbol("rightValue").cast("string"))
     val joined = left.join(
         right,
         left("key") === right("key")
           && left("window") === right("window")
-          && 'rightValue.cast("int") > 7,
+          && Symbol("rightValue").cast("int") > 7,
         joinType)
     val select = if (joinType == "left_semi") {
-      joined.select(left("key"), left("window.end").cast("long"), 'leftValue)
+      joined.select(left("key"), left("window.end").cast("long"), Symbol("leftValue"))
     } else if (joinType == "left_outer") {
-      joined.select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue)
+      joined.select(left("key"), left("window.end").cast("long"),
+        Symbol("leftValue"), Symbol("rightValue"))
     } else if (joinType == "right_outer") {
-      joined.select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue)
+      joined.select(right("key"), right("window.end").cast("long"),
+        Symbol("leftValue"), Symbol("rightValue"))
     } else {
-      joined.select(left("key"), left("window.end").cast("long"), 'leftValue,
-        right("key"), right("window.end").cast("long"), 'rightValue)
+      joined.select(left("key"), left("window.end").cast("long"), Symbol("leftValue"),
+        right("key"), right("window.end").cast("long"), Symbol("rightValue"))
     }
     (leftInput, rightInput, select)
@@ -143,12 +154,13 @@ abstract class StreamingJoinSuite
     val rightInput = MemoryStream[(Int, Int)]
     val df1 = leftInput.toDF.toDF("leftKey", "time")
-      .select('leftKey, timestamp_seconds($"time") as "leftTime", ('leftKey * 2) as "leftValue")
+      .select(Symbol("leftKey"),
+        timestamp_seconds($"time") as "leftTime", (Symbol("leftKey") * 2) as "leftValue")
       .withWatermark("leftTime", "10 seconds")
     val df2 = rightInput.toDF.toDF("rightKey", "time")
-      .select('rightKey, timestamp_seconds($"time") as "rightTime",
-        ('rightKey * 3) as "rightValue")
+      .select(Symbol("rightKey"), timestamp_seconds($"time") as "rightTime",
+        (Symbol("rightKey") * 3) as "rightValue")
       .withWatermark("rightTime", "10 seconds")
     val joined =
@@ -159,9 +171,10 @@ abstract class StreamingJoinSuite
         joinType)
     val select = if (joinType == "left_semi") {
-      joined.select('leftKey, 'leftTime.cast("int"))
+      joined.select(Symbol("leftKey"), Symbol("leftTime").cast("int"))
     } else {
-      joined.select('leftKey, 'rightKey, 'leftTime.cast("int"), 'rightTime.cast("int"))
+      joined.select(Symbol("leftKey"), Symbol("rightKey"),
+        Symbol("leftTime").cast("int"), Symbol("rightTime").cast("int"))
     }
     (leftInput, rightInput, select)
@@ -208,8 +221,8 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
     val input1 = MemoryStream[Int]
     val input2 = MemoryStream[Int]
-    val df1 = input1.toDF.select('value as "key", ('value * 2) as "leftValue")
-    val df2 = input2.toDF.select('value as "key", ('value * 3) as "rightValue")
+    val df1 = input1.toDF.select(Symbol("value") as "key", (Symbol("value") * 2) as "leftValue")
+    val df2 = input2.toDF.select(Symbol("value") as "key", (Symbol("value") * 3) as "rightValue")
     val joined = df1.join(df2, "key")
     testStream(joined)(
@@ -238,17 +251,17 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
     val input2 = MemoryStream[Int]
     val df1 = input1.toDF
-      .select('value as "key", timestamp_seconds($"value") as "timestamp",
-        ('value * 2) as "leftValue")
-      .select('key, window('timestamp, "10 second"), 'leftValue)
+      .select(Symbol("value") as "key", timestamp_seconds($"value") as "timestamp",
+        (Symbol("value") * 2) as "leftValue")
+      .select(Symbol("key"), window(Symbol("timestamp"), "10 second"), Symbol("leftValue"))
     val df2 = input2.toDF
-      .select('value as "key", timestamp_seconds($"value") as "timestamp",
-        ('value * 3) as "rightValue")
-      .select('key, window('timestamp, "10 second"), 'rightValue)
+      .select(Symbol("value") as "key", timestamp_seconds($"value") as "timestamp",
+        (Symbol("value") * 3) as "rightValue")
+      .select(Symbol("key"), window(Symbol("timestamp"), "10 second"), Symbol("rightValue"))
     val joined = df1.join(df2, Seq("key", "window"))
-      .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+      .select(Symbol("key"), $"window.end".cast("long"), Symbol("leftValue"), Symbol("rightValue"))
     testStream(joined)(
       AddData(input1, 1),
@@ -279,18 +292,18 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
     val input2 = MemoryStream[Int]
     val df1 = input1.toDF
-      .select('value as "key", timestamp_seconds($"value") as "timestamp",
-        ('value * 2) as "leftValue")
+      .select(Symbol("value") as "key", timestamp_seconds($"value") as "timestamp",
+        (Symbol("value") * 2) as "leftValue")
       .withWatermark("timestamp", "10 seconds")
-      .select('key, window('timestamp, "10 second"), 'leftValue)
+      .select(Symbol("key"), window(Symbol("timestamp"), "10 second"), Symbol("leftValue"))
     val df2 = input2.toDF
-      .select('value as "key", timestamp_seconds($"value") as "timestamp",
-        ('value * 3) as "rightValue")
-      .select('key, window('timestamp, "10 second"), 'rightValue)
+      .select(Symbol("value") as "key", timestamp_seconds($"value") as "timestamp",
+        (Symbol("value") * 3) as "rightValue")
+      .select(Symbol("key"), window(Symbol("timestamp"), "10 second"), Symbol("rightValue"))
     val joined = df1.join(df2, Seq("key", "window"))
-      .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue)
+      .select(Symbol("key"), $"window.end".cast("long"), Symbol("leftValue"), Symbol("rightValue"))
     testStream(joined)(
       AddData(input1, 1),
@@ -330,17 +343,18 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
     val rightInput = MemoryStream[(Int, Int)]
     val df1 = leftInput.toDF.toDF("leftKey", "time")
-      .select('leftKey, timestamp_seconds($"time") as "leftTime", ('leftKey * 2) as "leftValue")
+      .select(Symbol("leftKey"),
+        timestamp_seconds($"time") as "leftTime", (Symbol("leftKey") * 2) as "leftValue")
       .withWatermark("leftTime", "10 seconds")
     val df2 = rightInput.toDF.toDF("rightKey", "time")
-      .select('rightKey, timestamp_seconds($"time") as "rightTime",
-        ('rightKey * 3) as "rightValue")
+      .select(Symbol("rightKey"), timestamp_seconds($"time") as "rightTime",
+        (Symbol("rightKey") * 3) as "rightValue")
      .withWatermark("rightTime", "10 seconds")
     val joined =
       df1.join(df2, expr("leftKey = rightKey AND leftTime < rightTime - interval 5 seconds"))
-        .select('leftKey, 'leftTime.cast("int"), 'rightTime.cast("int"))
+        .select(Symbol("leftKey"), Symbol("leftTime").cast("int"), Symbol("rightTime").cast("int"))
     testStream(joined)(
       AddData(leftInput, (1, 5)),
@@ -389,12 +403,13 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite {
     val rightInput = MemoryStream[(Int, Int)]
     val df1 = leftInput.toDF.toDF("leftKey", "time")
-      .select('leftKey, timestamp_seconds($"time") as "leftTime", ('leftKey * 2) as "leftValue")
+      .select(Symbol("leftKey"),
+        timestamp_seconds($"time") as "leftTime", (Symbol("leftKey") * 2) as "leftValue")
.withWatermark("leftTime", "20 seconds") val df2 = rightInput.toDF.toDF("rightKey", "time") - .select('rightKey, timestamp_seconds($"time") as "rightTime", - ('rightKey * 3) as "rightValue") + .select(Symbol("rightKey"), timestamp_seconds($"time") as "rightTime", + (Symbol("rightKey") * 3) as "rightValue") .withWatermark("rightTime", "30 seconds") val condition = expr( @@ -422,8 +437,8 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { // values allowed: rightTime >= leftTime - 5s ==> rightTime > eventTimeWatermark - 5 // drop state where rightTime < eventTime - 5 - val joined = - df1.join(df2, condition).select('leftKey, 'leftTime.cast("int"), 'rightTime.cast("int")) + val joined = df1.join(df2, condition).select(Symbol("leftKey"), + Symbol("leftTime").cast("int"), Symbol("rightTime").cast("int")) testStream(joined)( // If leftTime = 20, then it match only with rightTime = [15, 30] @@ -470,8 +485,10 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val input1 = MemoryStream[Int] val input2 = MemoryStream[Int] - val df1 = input1.toDF.select('value as "leftKey", ('value * 2) as "leftValue") - val df2 = input2.toDF.select('value as "rightKey", ('value * 3) as "rightValue") + val df1 = + input1.toDF.select(Symbol("value") as "leftKey", (Symbol("value") * 2) as "leftValue") + val df2 = + input2.toDF.select(Symbol("value") as "rightKey", (Symbol("value") * 3) as "rightValue") val joined = df1.join(df2, expr("leftKey < rightKey")) val e = intercept[Exception] { val q = joined.writeStream.format("memory").queryName("test").start() @@ -485,8 +502,8 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val input = MemoryStream[Int] val df = input.toDF val join = - df.select('value % 5 as "key", 'value).join( - df.select('value % 5 as "key", 'value), "key") + df.select(Symbol("value") % 5 as "key", Symbol("value")).join( + df.select(Symbol("value") % 5 as "key", Symbol("value")), "key") testStream(join)( AddData(input, 1, 2), @@ -550,9 +567,12 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val input2 = MemoryStream[Int] val input3 = MemoryStream[Int] - val df1 = input1.toDF.select('value as "leftKey", ('value * 2) as "leftValue") - val df2 = input2.toDF.select('value as "middleKey", ('value * 3) as "middleValue") - val df3 = input3.toDF.select('value as "rightKey", ('value * 5) as "rightValue") + val df1 = + input1.toDF.select(Symbol("value") as "leftKey", (Symbol("value") * 2) as "leftValue") + val df2 = + input2.toDF.select(Symbol("value") as "middleKey", (Symbol("value") * 3) as "middleValue") + val df3 = + input3.toDF.select(Symbol("value") as "rightKey", (Symbol("value") * 5) as "rightValue") val joined = df1.join(df2, expr("leftKey = middleKey")).join(df3, expr("rightKey = middleKey")) @@ -567,9 +587,11 @@ class StreamingInnerJoinSuite extends StreamingJoinSuite { val input1 = MemoryStream[Int] val input2 = MemoryStream[Int] - val df1 = input1.toDF.select('value as 'a, 'value * 2 as 'b) - val df2 = input2.toDF.select('value as 'a, 'value * 2 as 'b).repartition('b) - val joined = df1.join(df2, Seq("a", "b")).select('a) + val df1 = + input1.toDF.select(Symbol("value") as Symbol("a"), Symbol("value") * 2 as Symbol("b")) + val df2 = input2.toDF.select( + Symbol("value") as Symbol("a"), Symbol("value") * 2 as Symbol("b")).repartition(Symbol("b")) + val joined = df1.join(df2, Seq("a", "b")).select(Symbol("a")) testStream(joined)( AddData(input1, 1.to(1000): _*), @@ -778,15 +800,18 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite { val 
(leftInput, simpleLeftDf) = setupStream("left", 2) val (rightInput, simpleRightDf) = setupStream("right", 3) - val left = simpleLeftDf.select('key, window('leftTime, "10 second"), 'leftValue) - val right = simpleRightDf.select('key, window('rightTime, "10 second"), 'rightValue) + val left = simpleLeftDf.select( + Symbol("key"), window(Symbol("leftTime"), "10 second"), Symbol("leftValue")) + val right = simpleRightDf.select( + Symbol("key"), window(Symbol("rightTime"), "10 second"), Symbol("rightValue")) val joined = left.join( - right, - left("key") === right("key") && left("window") === right("window") && - 'leftValue > 10 && ('rightValue < 300 || 'rightValue > 1000), - "left_outer") - .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + right, + left("key") === right("key") && left("window") === right("window") && + Symbol("leftValue") > 10 && + (Symbol("rightValue") < 300 || Symbol("rightValue") > 1000), "left_outer") + .select( + left("key"), left("window.end").cast("long"), Symbol("leftValue"), Symbol("rightValue")) testStream(joined)( // leftValue <= 10 should generate outer join rows even though it matches right keys @@ -977,9 +1002,9 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite { val input1 = MemoryStream[Int](desiredPartitionsForInput1) val df1 = input1.toDF .select( - 'value as "key", - 'value as "leftValue", - 'value as "rightValue") + Symbol("value") as "key", + Symbol("value") as "leftValue", + Symbol("value") as "rightValue") val (input2, df2) = setupStream("left", 2) val (input3, df3) = setupStream("right", 3) @@ -987,7 +1012,7 @@ class StreamingOuterJoinSuite extends StreamingJoinSuite { .join(df3, df2("key") === df3("key") && df2("leftTime") === df3("rightTime"), "inner") - .select(df2("key"), 'leftValue, 'rightValue) + .select(df2("key"), Symbol("leftValue"), Symbol("rightValue")) (input1, input2, input3, df1.union(joined)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala index 3892caa51eca..b656f2072462 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveScriptTransformationSuite.scala @@ -176,12 +176,12 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T """.stripMargin), identity, df.select( - 'a.cast("string").as("key"), + Symbol("a").cast("string").as("key"), concat_ws("\t", - 'b.cast("string"), - 'c.cast("string"), - 'd.cast("string"), - 'e.cast("string")).as("value")).collect()) + Symbol("b").cast("string"), + Symbol("c").cast("string"), + Symbol("d").cast("string"), + Symbol("e").cast("string")).as("value")).collect()) // In hive default serde mode, if we don't define output schema, // when output column size > 2 and just specify serde, @@ -204,8 +204,8 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T """.stripMargin), identity, df.select( - 'a.cast("string").as("key"), - 'b.cast("string").as("value")).collect()) + Symbol("a").cast("string").as("key"), + Symbol("b").cast("string").as("value")).collect()) // In hive default serde mode, if we don't define output schema, @@ -232,12 +232,12 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T """.stripMargin), identity, df.select( - 'a.cast("string").as("key"), + 
Symbol("a").cast("string").as("key"), concat_ws("\t", - 'b.cast("string"), - 'c.cast("string"), - 'd.cast("string"), - 'e.cast("string")).as("value")).collect()) + Symbol("b").cast("string"), + Symbol("c").cast("string"), + Symbol("d").cast("string"), + Symbol("e").cast("string")).as("value")).collect()) // In hive default serde mode, if we don't define output schema, // when output column size > 2 and specify serde @@ -262,8 +262,8 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T """.stripMargin), identity, df.select( - 'a.cast("string").as("key"), - 'b.cast("string").as("value")).collect()) + Symbol("a").cast("string").as("key"), + Symbol("b").cast("string").as("value")).collect()) // In hive default serde mode, if we don't define output schema, // when output column size = 2 and specify serde, it will these two column as @@ -287,8 +287,8 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T """.stripMargin), identity, df.select( - 'a.cast("string").as("key"), - 'b.cast("string").as("value")).collect()) + Symbol("a").cast("string").as("key"), + Symbol("b").cast("string").as("value")).collect()) // In hive default serde mode, if we don't define output schema, // when output column size < 2 and specify serde, it will return null for deficiency @@ -312,7 +312,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T """.stripMargin), identity, df.select( - 'a.cast("string").as("key"), + Symbol("a").cast("string").as("key"), lit(null)).collect()) } } @@ -325,8 +325,9 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T val df = Seq( (1, "1", Array(0, 1, 2), Map("a" -> 1)), (2, "2", Array(3, 4, 5), Map("b" -> 2)) - ).toDF("a", "b", "c", "d") - .select('a, 'b, 'c, 'd, struct('a, 'b).as("e")) + ).toDF("a", "b", "c", "d").select( + Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), + struct(Symbol("a"), Symbol("b")).as("e")) df.createTempView("v") // Hive serde support ArrayType/MapType/StructType as input and output data type @@ -348,7 +349,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T child = child, ioschema = hiveIOSchema ), - df.select('c, 'd, 'e).collect()) + df.select(Symbol("c"), Symbol("d"), Symbol("e")).collect()) } } @@ -358,8 +359,9 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T val df = Seq( (1, "1", Array(0, 1, 2), Map("a" -> 1)), (2, "2", Array(3, 4, 5), Map("b" -> 2)) - ).toDF("a", "b", "c", "d") - .select('a, 'b, 'c, 'd, struct('a, 'b).as("e")) + ).toDF("a", "b", "c", "d").select( + Symbol("a"), Symbol("b"), Symbol("c"), Symbol("d"), + struct(Symbol("a"), Symbol("b")).as("e")) df.createTempView("v") // Hive serde support ArrayType/MapType/StructType as input and output data type @@ -369,7 +371,7 @@ class HiveScriptTransformationSuite extends BaseScriptTransformationSuite with T |USING 'cat' AS (c array, d map, e struct) |FROM v """.stripMargin) - checkAnswer(query, identity, df.select('c, 'd, 'e).collect()) + checkAnswer(query, identity, df.select(Symbol("c"), Symbol("d"), Symbol("e")).collect()) } }