Commit 701dcd2

[SPARK-2781][SQL] Check resolution of LogicalPlans in Analyzer.

1 parent 79cdb9b

8 files changed, 103 insertions(+), 20 deletions(-)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala (+17 -4)

@@ -40,7 +40,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
   // TODO: pass this in as a parameter.
   val fixedPoint = FixedPoint(100)
 
-  val batches: Seq[Batch] = Seq(
+  /**
+   * Override to provide additional rules for the "Resolution" batch.
+   */
+  val extendedRules: List[Rule[LogicalPlan]] = Nil
+
+  lazy val batches: Seq[Batch] = Seq(
     Batch("MultiInstanceRelations", Once,
       NewRelationInstances),
     Batch("CaseInsensitiveAttributeReferences", Once,
@@ -54,23 +59,31 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
       StarExpansion ::
       ResolveFunctions ::
       GlobalAggregates ::
-      UnresolvedHavingClauseAttributes ::
-      typeCoercionRules :_*),
+      UnresolvedHavingClauseAttributes ::
+      typeCoercionRules :::
+      extendedRules : _*),
     Batch("Check Analysis", Once,
       CheckResolution),
     Batch("AnalysisOperators", fixedPoint,
       EliminateAnalysisOperators)
   )
 
   /**
-   * Makes sure all attributes have been resolved.
+   * Makes sure all attributes and logical plans have been resolved.
    */
   object CheckResolution extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = {
       plan.transform {
         case p if p.expressions.exists(!_.resolved) =>
           throw new TreeNodeException(p,
             s"Unresolved attributes: ${p.expressions.filterNot(_.resolved).mkString(",")}")
+        case p if !p.resolved && p.childrenResolved =>
+          throw new TreeNodeException(p, "Unresolved plan found")
+      } match {
+        // As a backstop, use the root node to check that the entire plan tree is resolved.
+        case p if !p.resolved =>
+          throw new TreeNodeException(p, "Unresolved plan in tree")
+        case p => p
       }
     }
   }
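The extendedRules hook above is this commit's extension point: a subclass of Analyzer appends rules to the end of the "Resolution" batch instead of patching plans after analysis (this is exactly what the HiveContext change later in this diff does). A minimal usage sketch, not part of this commit — MyResolutionRule is hypothetical, and catalog and functionRegistry are assumed to be in scope:

    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule

    // Hypothetical rule: a real one would pattern-match and rewrite plan nodes.
    object MyResolutionRule extends Rule[LogicalPlan] {
      def apply(plan: LogicalPlan): LogicalPlan = plan transform {
        // Defer any rewriting until children are resolved, as the Hive rules below do.
        case p: LogicalPlan if !p.childrenResolved => p
      }
    }

    val analyzer = new Analyzer(catalog, functionRegistry, caseSensitive = false) {
      override val extendedRules = MyResolutionRule :: Nil
    }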

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala (+4 -0)

@@ -286,6 +286,10 @@ trait HiveTypeCoercion {
     // If the data type is not boolean and is being cast boolean, turn it into a comparison
     // with the numeric value, i.e. x != 0. This will coerce the type into numeric type.
     case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0)))
+    // Stringify boolean if casting to StringType.
+    // TODO Ensure true/false string letter casing is consistent with Hive in all cases.
+    case Cast(e, StringType) if e.dataType == BooleanType =>
+      If(e, Literal("true"), Literal("false"))
     // Turn true into 1, and false into 0 if casting boolean into other types.
     case Cast(e, dataType) if e.dataType == BooleanType =>
       Cast(If(e, Literal(1), Literal(0)), dataType)
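The new case turns a boolean-to-string cast into a conditional over string literals. A small sketch of the before and after expression trees, using the same Catalyst classes exercised by the tests below (class paths taken from the suite's imports):

    import org.apache.spark.sql.catalyst.expressions.{Cast, If, Literal}
    import org.apache.spark.sql.catalyst.types.{BooleanType, StringType}

    // CAST(true AS STRING) ...
    val before = Cast(Literal(true), StringType)
    // ... is rewritten by the rule above to IF(true, "true", "false").
    val after = If(Literal(true), Literal("true"), Literal("false"))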

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala (+1 -1)

@@ -58,7 +58,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
 
   /**
    * Returns true if this expression and all its children have been resolved to a specific schema
-   * and false if it is still contains any unresolved placeholders. Implementations of LogicalPlan
+   * and false if it still contains any unresolved placeholders. Implementations of LogicalPlan
    * can override this (e.g.
    * [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]]
    * should return `false`).

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala (+12 -1)

@@ -93,6 +93,17 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
     val e = intercept[TreeNodeException[_]] {
       caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
     }
-    assert(e.getMessage().toLowerCase.contains("unresolved"))
+    assert(e.getMessage().toLowerCase.contains("unresolved attribute"))
+  }
+
+  test("throw errors for unresolved plans during analysis") {
+    case class UnresolvedTestPlan() extends LeafNode {
+      override lazy val resolved = false
+      override def output = Nil
+    }
+    val e = intercept[TreeNodeException[_]] {
+      caseSensitiveAnalyze(UnresolvedTestPlan())
+    }
+    assert(e.getMessage().toLowerCase.contains("unresolved plan"))
   }
 }
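Note how the test plan exercises the new check: UnresolvedTestPlan has no expressions (so the attribute check passes) and no children (so childrenResolved is trivially true), yet resolved is false, which triggers exactly the new `case p if !p.resolved && p.childrenResolved` branch in CheckResolution.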

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala (+14 -0)

@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.types._
 
 class HiveTypeCoercionSuite extends FunSuite {
@@ -84,4 +86,16 @@ class HiveTypeCoercionSuite extends FunSuite {
     widenTest(StringType, MapType(IntegerType, StringType, true), None)
     widenTest(ArrayType(IntegerType), StructType(Seq()), None)
   }
+
+  test("boolean casts") {
+    def ruleTest(initial: Expression, transformed: Expression) {
+      val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
+      assert(BooleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) ==
+        Project(Seq(Alias(transformed, "a")()), testRelation))
+    }
+    // Remove superfluous boolean -> boolean casts.
+    ruleTest(Cast(Literal(true), BooleanType), Literal(true))
+    // Stringify boolean when casting to string.
+    ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false")))
+  }
 }

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (+40 -4)

@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test._
 import org.scalatest.BeforeAndAfterAll
@@ -477,18 +478,48 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       (3, null)))
   }
 
-  test("EXCEPT") {
+  test("UNION") {
+    checkAnswer(
+      sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"),
+      (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") ::
+      (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil)
+    checkAnswer(
+      sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"),
+      (1, "a") :: (2, "b") :: (3, "c") :: (4, "d") :: Nil)
+    checkAnswer(
+      sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"),
+      (1, "a") :: (1, "a") :: (2, "b") :: (2, "b") :: (3, "c") :: (3, "c") ::
+      (4, "d") :: (4, "d") :: Nil)
+  }
+
+  test("UNION with column mismatches") {
+    // Column name mismatches are allowed.
+    checkAnswer(
+      sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"),
+      (1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") ::
+      (4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil)
+    // Column type mismatches are not allowed, forcing a type coercion.
+    checkAnswer(
+      sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"),
+      ("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Tuple1(_)))
+    // Column type mismatches where a coercion is not possible, in this case between integer
+    // and array types, trigger a TreeNodeException.
+    intercept[TreeNodeException[_]] {
+      sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect()
+    }
+  }
 
+  test("EXCEPT") {
     checkAnswer(
-      sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData "),
+      sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"),
       (1, "a") ::
       (2, "b") ::
       (3, "c") ::
       (4, "d") :: Nil)
     checkAnswer(
-      sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData "), Nil)
+      sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil)
     checkAnswer(
-      sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData "), Nil)
+      sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil)
   }
 
   test("INTERSECT") {
@@ -635,5 +666,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       Seq()
     )
 
+  test("cast boolean to string") {
+    // TODO Ensure true/false string letter casing is consistent with Hive in all cases.
+    checkAnswer(
+      sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"),
+      ("true", "false") :: Nil)
   }
 }
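The integer-versus-array case in "UNION with column mismatches" is the user-visible effect of this commit: no coercion rule applies, so the Union node stays unresolved even though all of its expressions are resolved, and the new CheckResolution batch now fails analysis with a TreeNodeException instead of letting the unresolved plan reach execution.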

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala (+5 -4)

@@ -262,7 +262,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
   /* An analyzer that uses the Hive metastore. */
   @transient
   override protected[sql] lazy val analyzer =
-    new Analyzer(catalog, functionRegistry, caseSensitive = false)
+    new Analyzer(catalog, functionRegistry, caseSensitive = false) {
+      override val extendedRules = catalog.CreateTables :: catalog.PreInsertionCasts :: Nil
+    }
 
   /**
    * Runs the specified SQL query using Hive.
@@ -353,9 +355,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
 
   /** Extends QueryExecution with hive specific features. */
   protected[sql] abstract class QueryExecution extends super.QueryExecution {
-    // TODO: Create mixin for the analyzer instead of overriding things here.
-    override lazy val optimizedPlan =
-      optimizer(ExtractPythonUdfs(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))))
+    // TODO: Utilize extendedRules in the analyzer instead of overriding things here.
+    override lazy val optimizedPlan = optimizer(ExtractPythonUdfs(analyzed))
 
     override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy())
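Design note: registering CreateTables and PreInsertionCasts through extendedRules means the Hive-specific rules now run inside the analyzer's fixed-point "Resolution" batch, with their output vetted by CheckResolution, rather than being applied once to the already-analyzed plan in optimizedPlan as before.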

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala (+10 -6)

@@ -109,15 +109,17 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
    */
   object CreateTables extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+      // Wait until children are resolved.
+      case p: LogicalPlan if !p.childrenResolved => p
+
       case InsertIntoCreatedTable(db, tableName, child) =>
         val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
         val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)
 
         createTable(databaseName, tblName, child.output)
 
         InsertIntoTable(
-          EliminateAnalysisOperators(
-            lookupRelation(Some(databaseName), tblName, None)),
+          lookupRelation(Some(databaseName), tblName, None),
           Map.empty,
           child,
           overwrite = false)
@@ -130,15 +132,17 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
    */
   object PreInsertionCasts extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
-      // Wait until children are resolved
+      // Wait until children are resolved.
      case p: LogicalPlan if !p.childrenResolved => p
 
-      case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
+      case p @ InsertIntoTable(
+          LowerCaseSchema(table: MetastoreRelation), _, child, _) =>
         castChildOutput(p, table, child)
 
       case p @ logical.InsertIntoTable(
-          InMemoryRelation(_, _, _,
-            HiveTableScan(_, table, _)), _, child, _) =>
+          LowerCaseSchema(
+            InMemoryRelation(_, _, _,
+              HiveTableScan(_, table, _))), _, child, _) =>
         castChildOutput(p, table, child)
     }