
Commit 5b93711

Replace LowerCaseSchema with Resolver.

1 parent 6d887db

13 files changed: +48 -93 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 10 additions & 28 deletions
@@ -37,6 +37,8 @@ object SimpleAnalyzer extends Analyzer(EmptyCatalog, EmptyFunctionRegistry, true
 class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Boolean)
   extends RuleExecutor[LogicalPlan] with HiveTypeCoercion {
 
+  val resolver = if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution
+
   // TODO: pass this in as a parameter.
   val fixedPoint = FixedPoint(100)

@@ -48,8 +50,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
   lazy val batches: Seq[Batch] = Seq(
     Batch("MultiInstanceRelations", Once,
       NewRelationInstances),
-    Batch("CaseInsensitiveAttributeReferences", Once,
-      (if (caseSensitive) Nil else LowercaseAttributeReferences :: Nil) : _*),
     Batch("Resolution", fixedPoint,
       ResolveReferences ::
       ResolveRelations ::
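
Note: the deleted `CaseInsensitiveAttributeReferences` batch implemented case-insensitivity by destructively lowercasing every name in one up-front pass; the new `resolver` compares names at each lookup instead, so the user's original casing survives in the plan. A minimal sketch of the difference, using the predicates this commit adds to the analysis package object:

    // Old approach: rewrite the name, losing its original casing.
    UnresolvedAttribute("myCol".toLowerCase)       // name is now "mycol"
    // New approach: keep the name as written, compare through the resolver.
    caseInsensitiveResolution("myCol", "MYCOL")    // true
    caseSensitiveResolution("myCol", "MYCOL")      // false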
@@ -98,23 +98,6 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
     }
   }
 
-  /**
-   * Makes attribute naming case insensitive by turning all UnresolvedAttributes to lowercase.
-   */
-  object LowercaseAttributeReferences extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-      case UnresolvedRelation(databaseName, name, alias) =>
-        UnresolvedRelation(databaseName, name, alias.map(_.toLowerCase))
-      case Subquery(alias, child) => Subquery(alias.toLowerCase, child)
-      case q: LogicalPlan => q transformExpressions {
-        case s: Star => s.copy(table = s.table.map(_.toLowerCase))
-        case UnresolvedAttribute(name) => UnresolvedAttribute(name.toLowerCase)
-        case Alias(c, name) => Alias(c, name.toLowerCase)()
-        case GetField(c, name) => GetField(c, name.toLowerCase)
-      }
-    }
-  }
-
   /**
    * Replaces [[UnresolvedAttribute]]s with concrete
    * [[catalyst.expressions.AttributeReference AttributeReferences]] from a logical plan node's
@@ -127,7 +110,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
       q transformExpressions {
         case u @ UnresolvedAttribute(name) =>
           // Leave unchanged if resolution fails. Hopefully will be resolved next round.
-          val result = q.resolveChildren(name).getOrElse(u)
+          val result = q.resolveChildren(name, resolver).getOrElse(u)
           logDebug(s"Resolving $u to $result")
           result
       }
@@ -144,7 +127,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
     def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
       case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
         val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
-        val resolved = unresolved.flatMap(child.resolveChildren)
+        val resolved = unresolved.flatMap(child.resolve(_, resolver))
         val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a })
 
         val missingInProject = requiredAttributes -- p.output
@@ -154,6 +137,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
           Sort(ordering,
             Project(projectList ++ missingInProject, child)))
         } else {
+          logDebug(s"Failed to find $missingInProject in ${p.output.mkString(", ")}")
           s // Nothing we can do here. Return original plan.
         }
       case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved =>
@@ -165,7 +149,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
         )
 
         logDebug(s"Grouping expressions: $groupingRelation")
-        val resolved = unresolved.flatMap(groupingRelation.resolve)
+        val resolved = unresolved.flatMap(groupingRelation.resolve(_, resolver))
         val missingInAggs = resolved.filterNot(a.outputSet.contains)
         logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
         if (missingInAggs.nonEmpty) {
@@ -258,22 +242,22 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
       case p @ Project(projectList, child) if containsStar(projectList) =>
         Project(
           projectList.flatMap {
-            case s: Star => s.expand(child.output)
+            case s: Star => s.expand(child.output, resolver)
             case o => o :: Nil
           },
           child)
       case t: ScriptTransformation if containsStar(t.input) =>
         t.copy(
           input = t.input.flatMap {
-            case s: Star => s.expand(t.child.output)
+            case s: Star => s.expand(t.child.output, resolver)
             case o => o :: Nil
           }
         )
       // If the aggregate function argument contains Stars, expand it.
      case a: Aggregate if containsStar(a.aggregateExpressions) =>
        a.copy(
          aggregateExpressions = a.aggregateExpressions.flatMap {
-            case s: Star => s.expand(a.child.output)
+            case s: Star => s.expand(a.child.output, resolver)
            case o => o :: Nil
          }
        )
@@ -290,13 +274,11 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
 /**
  * Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are
  * only required to provide scoping information for attributes and can be removed once analysis is
- * complete. Similarly, this node also removes
- * [[catalyst.plans.logical.LowerCaseSchema LowerCaseSchema]] operators.
+ * complete.
  */
 object EliminateAnalysisOperators extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case Subquery(_, child) => child
-    case LowerCaseSchema(child) => child
   }
 }
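
With LowerCaseSchema gone, eliding analysis-only operators is a one-case transform. A minimal usage sketch (`relation` is a hypothetical resolved plan):

    val plan = Subquery("t", relation)   // analysis-only scoping node
    EliminateAnalysisOperators(plan)     // returns `relation`, alias stripped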

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala

Lines changed: 6 additions & 1 deletion
@@ -22,4 +22,9 @@ package org.apache.spark.sql.catalyst
  * Analysis consists of translating [[UnresolvedAttribute]]s and [[UnresolvedRelation]]s
  * into fully typed objects using information in a schema [[Catalog]].
  */
-package object analysis
+package object analysis {
+  type Resolver = (String, String) => Boolean
+
+  val caseInsensitiveResolution = (a: String, b: String) => a.toLowerCase == b.toLowerCase
+  val caseSensitiveResolution = (a: String, b: String) => a == b
+}
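
A Resolver is just a binary predicate on names, so the resolution policy can be swapped without rewriting the plan. A minimal sketch of the wiring the Analyzer performs (the `pick` helper is hypothetical; the Analyzer inlines this choice):

    import org.apache.spark.sql.catalyst.analysis._

    def pick(caseSensitive: Boolean): Resolver =
      if (caseSensitive) caseSensitiveResolution else caseInsensitiveResolution

    pick(caseSensitive = false)("Key", "key")   // true
    pick(caseSensitive = true)("Key", "key")    // false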

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 2 additions & 2 deletions
@@ -98,12 +98,12 @@ case class Star(
   override def withNullability(newNullability: Boolean) = this
   override def withQualifiers(newQualifiers: Seq[String]) = this
 
-  def expand(input: Seq[Attribute]): Seq[NamedExpression] = {
+  def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = {
     val expandedAttributes: Seq[Attribute] = table match {
       // If there is no table specified, use all input attributes.
       case None => input
       // If there is a table, pick out attributes that are part of this table.
-      case Some(t) => input.filter(_.qualifiers contains t)
+      case Some(t) => input.filter(_.qualifiers.filter(resolver(_, t)).nonEmpty)
     }
     val mappedAttributes = expandedAttributes.map(mapFunction).zip(input).map {
       case (n: NamedExpression, _) => n
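
Qualifier matching for `t.*` now also goes through the resolver, so a query like `SELECT T.* FROM t` still expands under case-insensitive analysis. Note that `_.qualifiers.filter(resolver(_, t)).nonEmpty` is equivalent to the more idiomatic `_.qualifiers.exists(resolver(_, t))`. A small illustration with a hypothetical qualifier list:

    val qualifiers = Seq("Emp")                               // hypothetical alias
    qualifiers.exists(caseInsensitiveResolution(_, "emp"))    // true: Emp.* matches emp
    qualifiers.exists(caseSensitiveResolution(_, "emp"))      // false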

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 15 additions & 8 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
+import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
@@ -75,19 +76,23 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
    * nodes of this LogicalPlan. The attribute is expressed as
    * as string in the following form: `[scope].AttributeName.[nested].[fields]...`.
    */
-  def resolveChildren(name: String): Option[NamedExpression] =
-    resolve(name, children.flatMap(_.output))
+  def resolveChildren(name: String, resolver: Resolver): Option[NamedExpression] =
+    resolve(name, children.flatMap(_.output), resolver)
 
   /**
    * Optionally resolves the given string to a [[NamedExpression]] based on the output of this
    * LogicalPlan. The attribute is expressed as string in the following form:
    * `[scope].AttributeName.[nested].[fields]...`.
    */
-  def resolve(name: String): Option[NamedExpression] =
-    resolve(name, output)
+  def resolve(name: String, resolver: Resolver): Option[NamedExpression] =
+    resolve(name, output, resolver)
 
   /** Performs attribute resolution given a name and a sequence of possible attributes. */
-  protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = {
+  protected def resolve(
+      name: String,
+      input: Seq[Attribute],
+      resolver: Resolver): Option[NamedExpression] = {
+
     val parts = name.split("\\.")
     // Collect all attributes that are output by this nodes children where either the first part
     // matches the name or where the first part matches the scope and the second part matches the
@@ -96,16 +101,18 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
     val options = input.flatMap { option =>
       // If the first part of the desired name matches a qualifier for this possible match, drop it.
       val remainingParts =
-        if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts
-      if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil
+        if (option.qualifiers.filter(resolver(_, parts.head)).nonEmpty && parts.size > 1) parts.drop(1) else parts
+      if (resolver(option.name, remainingParts.head)) (option, remainingParts.tail.toList) :: Nil else Nil
     }
 
     options.distinct match {
       case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it.
       // One match, but we also need to extract the requested nested field.
       case Seq((a, nestedFields)) =>
         Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
-      case Seq() => None // No matches.
+      case Seq() =>
+        println(s"Could not find $name in ${input.mkString(", ")}")
+        None // No matches.
       case ambiguousReferences =>
         throw new TreeNodeException(
           this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
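
Tracing resolve by hand (a sketch, assuming a child plan that outputs an attribute `a` with a nested field `b`, qualified by the alias `t`):

    val parts = "T.A.B".split("\\.")   // Array("T", "A", "B")
    // 1. resolver("t", "T") matches the qualifier, so remainingParts = Array("A", "B")
    // 2. resolver("a", "A") matches the attribute name; nested fields = List("B")
    // 3. The single match becomes Alias(GetField(a, "B"), "B")()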

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 0 additions & 26 deletions
@@ -154,32 +154,6 @@ case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
   override def output = child.output.map(_.withQualifiers(alias :: Nil))
 }
 
-/**
- * Converts the schema of `child` to all lowercase, together with LowercaseAttributeReferences
- * this allows for optional case insensitive attribute resolution. This node can be elided after
- * analysis.
- */
-case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
-  protected def lowerCaseSchema(dataType: DataType): DataType = dataType match {
-    case StructType(fields) =>
-      StructType(fields.map(f =>
-        StructField(f.name.toLowerCase(), lowerCaseSchema(f.dataType), f.nullable)))
-    case ArrayType(elemType, containsNull) => ArrayType(lowerCaseSchema(elemType), containsNull)
-    case otherType => otherType
-  }
-
-  override val output = child.output.map {
-    case a: AttributeReference =>
-      AttributeReference(
-        a.name.toLowerCase,
-        lowerCaseSchema(a.dataType),
-        a.nullable)(
-        a.exprId,
-        a.qualifiers)
-    case other => other
-  }
-}
-
 case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
   extends UnaryNode {

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 1 addition & 1 deletion
@@ -246,7 +246,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * @group userf
    */
   def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
-    catalog.registerTable(None, tableName, rdd.queryExecution.analyzed)
+    catalog.registerTable(None, tableName, rdd.queryExecution.logical)
   }
 
   /**
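
Presumably the switch from `analyzed` to `logical` keeps registered tables un-analyzed, so name resolution (now governed by the session's resolver) happens when the table is queried rather than being frozen at registration time. A usage sketch, with a hypothetical SchemaRDD `people`:

    sqlContext.registerRDDAsTable(people, "people")
    // The catalog now stores people.queryExecution.logical, the raw plan.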

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 0 additions & 2 deletions
@@ -380,7 +380,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
   }
 
   test("SPARK-3349 partitioning after limit") {
-    /*
     sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC")
       .limit(2)
       .registerTempTable("subset1")
@@ -395,7 +394,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
       sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"),
       (1, "a", 1) ::
       (2, "b", 2) :: Nil)
-    */
   }
 
   test("mixed-case keywords") {

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

Lines changed: 1 addition & 9 deletions
@@ -244,15 +244,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
 
   /* A catalyst metadata catalog that points to the Hive Metastore. */
   @transient
-  override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog {
-    override def lookupRelation(
-      databaseName: Option[String],
-      tableName: String,
-      alias: Option[String] = None): LogicalPlan = {
-
-      LowerCaseSchema(super.lookupRelation(databaseName, tableName, alias))
-    }
-  }
+  override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog
 
   // Note that HiveUDFs will be overridden by functions registered in this context.
   @transient

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala

Lines changed: 2 additions & 4 deletions
@@ -129,14 +129,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
     // Wait until children are resolved.
     case p: LogicalPlan if !p.childrenResolved => p
 
-    case p @ InsertIntoTable(
-        LowerCaseSchema(table: MetastoreRelation), _, child, _) =>
+    case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
       castChildOutput(p, table, child)
 
     case p @ logical.InsertIntoTable(
-        LowerCaseSchema(
         InMemoryRelation(_, _, _,
-          HiveTableScan(_, table, _))), _, child, _) =>
+          HiveTableScan(_, table, _)), _, child, _) =>
       castChildOutput(p, table, child)
   }
sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LowerCaseSchema}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.types.StringType
 import org.apache.spark.sql.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan}
@@ -55,7 +55,7 @@ private[hive] trait HiveStrategies {
   object ParquetConversion extends Strategy {
     implicit class LogicalPlanHacks(s: SchemaRDD) {
       def lowerCase =
-        new SchemaRDD(s.sqlContext, LowerCaseSchema(s.logicalPlan))
+        new SchemaRDD(s.sqlContext, s.logicalPlan)
 
       def addPartitioningAttributes(attrs: Seq[Attribute]) =
         new SchemaRDD(
