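Every change in this diff follows one mechanical pattern: Scala symbol literals such as 'id, which are deprecated as of Scala 2.13, are rewritten as explicit Symbol("id") calls. Both spellings reach org.apache.spark.sql.Column through the same Symbol-to-ColumnName implicit in spark.implicits._, so test behaviour is unchanged. Below is a minimal, self-contained sketch of the pattern for reference only — it is not part of this PR; the SparkSession setup, object name, and example data are assumptions added purely for illustration.

import org.apache.spark.sql.SparkSession

// Standalone sketch of the symbol-literal migration (not part of this PR's diff).
object SymbolMigrationSketch {
  def main(args: Array[String]): Unit = {
    // Assumption: a throwaway local SparkSession, only for demonstration.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("symbol-migration-sketch")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b"), (2, "c")).toDF("id", "value")

    // Before: symbol literal, deprecated in Scala 2.13.
    //   df.groupBy('id).count()
    // After: explicit Symbol(...) call, resolved to a Column by the same
    // Symbol => ColumnName implicit provided by spark.implicits._.
    df.groupBy(Symbol("id")).count().show()

    // $"id" and col("id") remain untouched alternatives that avoid Symbol entirely.
    spark.stop()
  }
}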
@@ -83,7 +83,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
// key "b" => (19:39:27 ~ 19:39:37)

checkAnswer(
df.groupBy(session_window($"time", "10 seconds"), 'id)
df.groupBy(session_window($"time", "10 seconds"), Symbol("id"))
.agg(count("*").as("counts"), sum("value").as("sum"))
.orderBy($"session_window.start".asc)
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)",
@@ -113,7 +113,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
// key "b" => (19:39:27 ~ 19:39:37)

checkAnswer(
df.groupBy(session_window($"time", "10 seconds"), 'id)
df.groupBy(session_window($"time", "10 seconds"), Symbol("id"))
.agg(count("*").as("counts"), sum_distinct(col("value")).as("sum"))
.orderBy($"session_window.start".asc)
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)",
@@ -142,7 +142,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
// key "b" => (19:39:27 ~ 19:39:37)

checkAnswer(
df.groupBy(session_window($"time", "10 seconds"), 'id)
df.groupBy(session_window($"time", "10 seconds"), Symbol("id"))
.agg(sum_distinct(col("value")).as("sum"), sum_distinct(col("value2")).as("sum2"))
.orderBy($"session_window.start".asc)
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)",
@@ -171,7 +171,7 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
// b => (19:39:27 ~ 19:39:37), (19:39:39 ~ 19:39:55)

checkAnswer(
df.groupBy(session_window($"time", "10 seconds"), 'id)
df.groupBy(session_window($"time", "10 seconds"), Symbol("id"))
.agg(count("*").as("counts"), sum("value").as("sum"))
.orderBy($"session_window.start".asc)
.selectExpr("CAST(session_window.start AS STRING)", "CAST(session_window.end AS STRING)",
@@ -593,7 +593,7 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit
}

test("SPARK-35884: Explain should only display one plan before AQE takes effect") {
val df = (0 to 10).toDF("id").where('id > 5)
val df = (0 to 10).toDF("id").where(Symbol("id") > 5)
val modes = Seq(SimpleMode, ExtendedMode, CostMode, FormattedMode)
modes.foreach { mode =>
checkKeywordsExistsInExplain(df, mode, "AdaptiveSparkPlan")
@@ -608,7 +608,8 @@ class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuit

test("SPARK-35884: Explain formatted with subquery") {
withTempView("t1", "t2") {
spark.range(100).select('id % 10 as "key", 'id as "value").createOrReplaceTempView("t1")
spark.range(100).select(Symbol("id") % 10 as "key", Symbol("id") as "value")
.createOrReplaceTempView("t1")
spark.range(10).createOrReplaceTempView("t2")
val query =
"""
@@ -967,52 +967,57 @@ class FileBasedDataSourceSuite extends QueryTest

// cases when value == MAX
var v = Short.MaxValue
checkPushedFilters(format, df.where('id > v.toInt), Array(), noScan = true)
checkPushedFilters(format, df.where('id >= v.toInt), Array(sources.IsNotNull("id"),
sources.EqualTo("id", v)))
checkPushedFilters(format, df.where('id === v.toInt), Array(sources.IsNotNull("id"),
sources.EqualTo("id", v)))
checkPushedFilters(format, df.where('id <=> v.toInt),
checkPushedFilters(format, df.where(Symbol("id") > v.toInt), Array(), noScan = true)
checkPushedFilters(format, df.where(Symbol("id") >= v.toInt),
Array(sources.IsNotNull("id"), sources.EqualTo("id", v)))
checkPushedFilters(format, df.where(Symbol("id") === v.toInt),
Array(sources.IsNotNull("id"), sources.EqualTo("id", v)))
checkPushedFilters(format, df.where(Symbol("id") <=> v.toInt),
Array(sources.EqualNullSafe("id", v)))
checkPushedFilters(format, df.where('id <= v.toInt), Array(sources.IsNotNull("id")))
checkPushedFilters(format, df.where('id < v.toInt), Array(sources.IsNotNull("id"),
sources.Not(sources.EqualTo("id", v))))
checkPushedFilters(format, df.where(Symbol("id") <= v.toInt),
Array(sources.IsNotNull("id")))
checkPushedFilters(format, df.where(Symbol("id") < v.toInt),
Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v))))

// cases when value > MAX
var v1: Int = positiveInt
checkPushedFilters(format, df.where('id > v1), Array(), noScan = true)
checkPushedFilters(format, df.where('id >= v1), Array(), noScan = true)
checkPushedFilters(format, df.where('id === v1), Array(), noScan = true)
checkPushedFilters(format, df.where('id <=> v1), Array(), noScan = true)
checkPushedFilters(format, df.where('id <= v1), Array(sources.IsNotNull("id")))
checkPushedFilters(format, df.where('id < v1), Array(sources.IsNotNull("id")))
checkPushedFilters(format, df.where(Symbol("id") > v1), Array(), noScan = true)
checkPushedFilters(format, df.where(Symbol("id") >= v1), Array(), noScan = true)
checkPushedFilters(format, df.where(Symbol("id") === v1), Array(), noScan = true)
checkPushedFilters(format, df.where(Symbol("id") <=> v1), Array(), noScan = true)
checkPushedFilters(format, df.where(Symbol("id") <= v1), Array(sources.IsNotNull("id")))
checkPushedFilters(format, df.where(Symbol("id") < v1), Array(sources.IsNotNull("id")))

// cases when value = MIN
v = Short.MinValue
checkPushedFilters(format, df.where(lit(v.toInt) < 'id), Array(sources.IsNotNull("id"),
sources.Not(sources.EqualTo("id", v))))
checkPushedFilters(format, df.where(lit(v.toInt) <= 'id), Array(sources.IsNotNull("id")))
checkPushedFilters(format, df.where(lit(v.toInt) === 'id), Array(sources.IsNotNull("id"),
checkPushedFilters(format, df.where(lit(v.toInt) < Symbol("id")),
Array(sources.IsNotNull("id"), sources.Not(sources.EqualTo("id", v))))
checkPushedFilters(format, df.where(lit(v.toInt) <= Symbol("id")),
Array(sources.IsNotNull("id")))
checkPushedFilters(format, df.where(lit(v.toInt) === Symbol("id")),
Array(sources.IsNotNull("id"),
sources.EqualTo("id", v)))
checkPushedFilters(format, df.where(lit(v.toInt) <=> 'id),
checkPushedFilters(format, df.where(lit(v.toInt) <=> Symbol("id")),
Array(sources.EqualNullSafe("id", v)))
checkPushedFilters(format, df.where(lit(v.toInt) >= 'id), Array(sources.IsNotNull("id"),
sources.EqualTo("id", v)))
checkPushedFilters(format, df.where(lit(v.toInt) > 'id), Array(), noScan = true)
checkPushedFilters(format, df.where(lit(v.toInt) >= Symbol("id")),
Array(sources.IsNotNull("id"), sources.EqualTo("id", v)))
checkPushedFilters(format, df.where(lit(v.toInt) > Symbol("id")), Array(), noScan = true)

// cases when value < MIN
v1 = negativeInt
checkPushedFilters(format, df.where(lit(v1) < 'id), Array(sources.IsNotNull("id")))
checkPushedFilters(format, df.where(lit(v1) <= 'id), Array(sources.IsNotNull("id")))
checkPushedFilters(format, df.where(lit(v1) === 'id), Array(), noScan = true)
checkPushedFilters(format, df.where(lit(v1) >= 'id), Array(), noScan = true)
checkPushedFilters(format, df.where(lit(v1) > 'id), Array(), noScan = true)
checkPushedFilters(format, df.where(lit(v1) < Symbol("id")),
Array(sources.IsNotNull("id")))
checkPushedFilters(format, df.where(lit(v1) <= Symbol("id")),
Array(sources.IsNotNull("id")))
checkPushedFilters(format, df.where(lit(v1) === Symbol("id")), Array(), noScan = true)
checkPushedFilters(format, df.where(lit(v1) >= Symbol("id")), Array(), noScan = true)
checkPushedFilters(format, df.where(lit(v1) > Symbol("id")), Array(), noScan = true)

// cases when value is within range (MIN, MAX)
checkPushedFilters(format, df.where('id > 30), Array(sources.IsNotNull("id"),
checkPushedFilters(format, df.where(Symbol("id") > 30), Array(sources.IsNotNull("id"),
sources.GreaterThan("id", 30)))
checkPushedFilters(format, df.where(lit(100) >= 'id), Array(sources.IsNotNull("id"),
sources.LessThanOrEqual("id", 100)))
checkPushedFilters(format, df.where(lit(100) >= Symbol("id")),
Array(sources.IsNotNull("id"), sources.LessThanOrEqual("id", 100)))
}
}
}
@@ -85,10 +85,11 @@ trait FileScanSuiteBase extends SharedSparkSession {
val options = new CaseInsensitiveStringMap(ImmutableMap.copyOf(optionsMap))
val optionsNotEqual =
new CaseInsensitiveStringMap(ImmutableMap.copyOf(ImmutableMap.of("key2", "value2")))
val partitionFilters = Seq(And(IsNull('data.int), LessThan('data.int, 0)))
val partitionFiltersNotEqual = Seq(And(IsNull('data.int), LessThan('data.int, 1)))
val dataFilters = Seq(And(IsNull('data.int), LessThan('data.int, 0)))
val dataFiltersNotEqual = Seq(And(IsNull('data.int), LessThan('data.int, 1)))
val partitionFilters = Seq(And(IsNull(Symbol("data").int), LessThan(Symbol("data").int, 0)))
val partitionFiltersNotEqual = Seq(And(IsNull(Symbol("data").int),
LessThan(Symbol("data").int, 1)))
val dataFilters = Seq(And(IsNull(Symbol("data").int), LessThan(Symbol("data").int, 0)))
val dataFiltersNotEqual = Seq(And(IsNull(Symbol("data").int), LessThan(Symbol("data").int, 1)))

scanBuilders.foreach { case (name, scanBuilder, exclusions) =>
test(s"SPARK-33482: Test $name equals") {
8 changes: 4 additions & 4 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -183,7 +183,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
test("inner join where, one match per row") {
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
checkAnswer(
upperCaseData.join(lowerCaseData).where('n === 'N),
upperCaseData.join(lowerCaseData).where(Symbol("n") === 'N),
Seq(
Row(1, "A", 1, "a"),
Row(2, "B", 2, "b"),
@@ -404,8 +404,8 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan

test("full outer join") {
withTempView("`left`", "`right`") {
upperCaseData.where('N <= 4).createOrReplaceTempView("`left`")
upperCaseData.where('N >= 3).createOrReplaceTempView("`right`")
upperCaseData.where(Symbol("N") <= 4).createOrReplaceTempView("`left`")
upperCaseData.where(Symbol("N") >= 3).createOrReplaceTempView("`right`")

val left = UnresolvedRelation(TableIdentifier("left"))
val right = UnresolvedRelation(TableIdentifier("right"))
@@ -623,7 +623,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
testData.createOrReplaceTempView("B")
testData2.createOrReplaceTempView("C")
testData3.createOrReplaceTempView("D")
upperCaseData.where('N >= 3).createOrReplaceTempView("`right`")
upperCaseData.where(Symbol("N") >= 3).createOrReplaceTempView("`right`")
val cartesianQueries = Seq(
/** The following should error out since there is no explicit cross join */
"SELECT * FROM testData inner join testData2",
@@ -403,7 +403,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession {

test("SPARK-24709: infers schemas of json strings and pass them to from_json") {
val in = Seq("""{"a": [1, 2, 3]}""").toDS()
val out = in.select(from_json('value, schema_of_json("""{"a": [1]}""")) as "parsed")
val out = in.select(from_json(Symbol("value"), schema_of_json("""{"a": [1]}""")) as "parsed")
val expected = StructType(StructField(
"parsed",
StructType(StructField(
@@ -47,12 +47,12 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
c: Column => Column,
f: T => U): Unit = {
checkAnswer(
doubleData.select(c('a)),
doubleData.select(c(Symbol("a"))),
(1 to 10).map(n => Row(f((n * 0.2 - 1).asInstanceOf[T])))
)

checkAnswer(
doubleData.select(c('b)),
doubleData.select(c(Symbol("b"))),
(1 to 10).map(n => Row(f((-n * 0.2 + 1).asInstanceOf[T])))
)

@@ -65,13 +65,13 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
private def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit =
{
checkAnswer(
nnDoubleData.select(c('a)),
nnDoubleData.select(c(Symbol("a"))),
(1 to 10).map(n => Row(f(n * 0.1)))
)

if (f(-1) === StrictMath.log1p(-1)) {
checkAnswer(
nnDoubleData.select(c('b)),
nnDoubleData.select(c(Symbol("b"))),
(1 to 9).map(n => Row(f(n * -0.1))) :+ Row(null)
)
}
@@ -87,12 +87,12 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
d: (Column, Double) => Column,
f: (Double, Double) => Double): Unit = {
checkAnswer(
nnDoubleData.select(c('a, 'a)),
nnDoubleData.select(c('a, Symbol("a"))),
nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(0))))
)

checkAnswer(
nnDoubleData.select(c('a, 'b)),
nnDoubleData.select(c('a, Symbol("b"))),
nnDoubleData.collect().toSeq.map(r => Row(f(r.getDouble(0), r.getDouble(1))))
)

Expand All @@ -109,7 +109,7 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
val nonNull = nullDoubles.collect().toSeq.filter(r => r.get(0) != null)

checkAnswer(
nullDoubles.select(c('a, 'a)).orderBy('a.asc),
nullDoubles.select(c('a, Symbol("a"))).orderBy(Symbol("a").asc),
Row(null) +: nonNull.map(r => Row(f(r.getDouble(0), r.getDouble(0))))
)
}
@@ -255,7 +255,7 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
test("factorial") {
val df = (0 to 5).map(i => (i, i)).toDF("a", "b")
checkAnswer(
df.select(factorial('a)),
df.select(factorial(Symbol("a"))),
Seq(Row(1), Row(1), Row(2), Row(6), Row(24), Row(120))
)
checkAnswer(
@@ -271,11 +271,11 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
test("round/bround/ceil/floor") {
val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a")
checkAnswer(
df.select(round('a), round('a, -1), round('a, -2)),
df.select(round(Symbol("a")), round('a, -1), round('a, -2)),
Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600))
)
checkAnswer(
df.select(bround('a), bround('a, -1), bround('a, -2)),
df.select(bround(Symbol("a")), bround('a, -1), bround('a, -2)),
Seq(Row(5, 0, 0), Row(55, 60, 100), Row(555, 560, 600))
)
checkAnswer(
@@ -343,11 +343,11 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
test("round/bround/ceil/floor with data frame from a local Seq of Product") {
val df = spark.createDataFrame(Seq(Tuple1(BigDecimal("5.9")))).toDF("value")
checkAnswer(
df.withColumn("value_rounded", round('value)),
df.withColumn("value_rounded", round(Symbol("value"))),
Seq(Row(BigDecimal("5.9"), BigDecimal("6")))
)
checkAnswer(
df.withColumn("value_brounded", bround('value)),
df.withColumn("value_brounded", bround(Symbol("value"))),
Seq(Row(BigDecimal("5.9"), BigDecimal("6")))
)
checkAnswer(
@@ -423,10 +423,10 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {

test("hex") {
val data = Seq((28, -28, 100800200404L, "hello")).toDF("a", "b", "c", "d")
checkAnswer(data.select(hex('a)), Seq(Row("1C")))
checkAnswer(data.select(hex('b)), Seq(Row("FFFFFFFFFFFFFFE4")))
checkAnswer(data.select(hex('c)), Seq(Row("177828FED4")))
checkAnswer(data.select(hex('d)), Seq(Row("68656C6C6F")))
checkAnswer(data.select(hex(Symbol("a"))), Seq(Row("1C")))
checkAnswer(data.select(hex(Symbol("b"))), Seq(Row("FFFFFFFFFFFFFFE4")))
checkAnswer(data.select(hex(Symbol("c"))), Seq(Row("177828FED4")))
checkAnswer(data.select(hex(Symbol("d"))), Seq(Row("68656C6C6F")))
checkAnswer(data.selectExpr("hex(a)"), Seq(Row("1C")))
checkAnswer(data.selectExpr("hex(b)"), Seq(Row("FFFFFFFFFFFFFFE4")))
checkAnswer(data.selectExpr("hex(c)"), Seq(Row("177828FED4")))
@@ -436,8 +436,8 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {

test("unhex") {
val data = Seq(("1C", "737472696E67")).toDF("a", "b")
checkAnswer(data.select(unhex('a)), Row(Array[Byte](28.toByte)))
checkAnswer(data.select(unhex('b)), Row("string".getBytes(StandardCharsets.UTF_8)))
checkAnswer(data.select(unhex(Symbol("a"))), Row(Array[Byte](28.toByte)))
checkAnswer(data.select(unhex(Symbol("b"))), Row("string".getBytes(StandardCharsets.UTF_8)))
checkAnswer(data.selectExpr("unhex(a)"), Row(Array[Byte](28.toByte)))
checkAnswer(data.selectExpr("unhex(b)"), Row("string".getBytes(StandardCharsets.UTF_8)))
checkAnswer(data.selectExpr("""unhex("##")"""), Row(null))
@@ -3066,15 +3066,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
val df = spark.read.format(format).load(dir.getCanonicalPath)
checkPushedFilters(
format,
df.where(('id < 2 and 's.contains("foo")) or ('id > 10 and 's.contains("bar"))),
df.where((Symbol("id") < 2 and Symbol("s").contains("foo")) or
(Symbol("id") > 10 and Symbol("s").contains("bar"))),
Array(sources.Or(sources.LessThan("id", 2), sources.GreaterThan("id", 10))))
checkPushedFilters(
format,
df.where('s.contains("foo") or ('id > 10 and 's.contains("bar"))),
df.where(Symbol("s").contains("foo") or
(Symbol("id") > 10 and Symbol("s").contains("bar"))),
Array.empty)
checkPushedFilters(
format,
df.where('id < 2 and not('id > 10 and 's.contains("bar"))),
df.where(Symbol("id") < 2 and not(Symbol("id") > 10 and Symbol("s").contains("bar"))),
Array(sources.IsNotNull("id"), sources.LessThan("id", 2)))
}
}
@@ -407,9 +407,9 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
withTable("TBL1", "TBL") {
import org.apache.spark.sql.functions._
val df = spark.range(1000L).select('id,
'id * 2 as "FLD1",
'id * 12 as "FLD2",
lit(null).cast(DoubleType) + 'id as "fld3")
Symbol("id") * 2 as "FLD1",
Symbol("id") * 12 as "FLD2",
lit(null).cast(DoubleType) + Symbol("id") as "fld3")
df.write
.mode(SaveMode.Overwrite)
.bucketBy(10, "id", "FLD1", "FLD2")