Skip to content

Commit 1d70c4f

Browse files
committed
[SPARK-2866][SQL] Support attributes in ORDER BY that aren't in SELECT
Minor refactoring to allow resolution using either a node's input or its output. Author: Michael Armbrust <[email protected]> Closes #1795 from marmbrus/ordering and squashes the following commits: 237f580 [Michael Armbrust] style 74d833b [Michael Armbrust] newline 705d963 [Michael Armbrust] Add a rule for resolving ORDER BY expressions that reference attributes not present in the SELECT clause. 82cabda [Michael Armbrust] Generalize attribute resolution.
1 parent 69ec678 commit 1d70c4f

File tree

3 files changed

+116
-7
lines changed

3 files changed

+116
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
4848
Batch("Resolution", fixedPoint,
4949
ResolveReferences ::
5050
ResolveRelations ::
51+
ResolveSortReferences ::
5152
NewRelationInstances ::
5253
ImplicitGenerate ::
5354
StarExpansion ::
@@ -113,13 +114,58 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
113114
q transformExpressions {
114115
case u @ UnresolvedAttribute(name) =>
115116
// Leave unchanged if resolution fails. Hopefully will be resolved next round.
116-
val result = q.resolve(name).getOrElse(u)
117+
val result = q.resolveChildren(name).getOrElse(u)
117118
logDebug(s"Resolving $u to $result")
118119
result
119120
}
120121
}
121122
}
122123

124+
/**
 * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT
 * clause. This rule detects such queries and adds the required attributes to the original
 * projection (or aggregation), so that they will be available during sorting. Another projection
 * is added on top to remove these attributes after sorting, so the query's output is unchanged.
 */
object ResolveSortReferences extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    // Sort directly on top of an already-resolved projection whose ordering is still unresolved.
    case s @ Sort(ordering, p @ Project(projectList, child)) if !s.resolved && p.resolved =>
      val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
      // Only attributes produced by `child` can be added to the projection, so resolve against
      // the child's own output rather than against the output of the child's children.
      val resolved = unresolved.flatMap(name => child.resolve(name))
      val requiredAttributes = resolved.collect { case a: Attribute => a }.toSet

      val missingInProject = requiredAttributes -- p.output
      if (missingInProject.nonEmpty) {
        // Add missing attributes and then project them away after the sort. The outer projection
        // must refer to the *output* of the original projection: re-evaluating `projectList`
        // here would reference inputs that the inner projection no longer exposes.
        Project(p.output,
          Sort(ordering,
            Project(projectList ++ missingInProject, child)))
      } else {
        s // Nothing we can do here. Return original plan.
      }

    // Sort directly on top of an already-resolved aggregate whose ordering is still unresolved.
    case s @ Sort(ordering, a @ Aggregate(grouping, aggs, child)) if !s.resolved && a.resolved =>
      val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
      // A small hack to create an object that will allow us to resolve any references that
      // refer to named expressions that are present in the grouping expressions.
      val groupingRelation = LocalRelation(
        grouping.collect { case ne: NamedExpression => ne.toAttribute }
      )

      logDebug(s"Grouping expressions: $groupingRelation")
      val resolved = unresolved.flatMap(groupingRelation.resolve).toSet
      val missingInAggs = resolved -- a.outputSet
      logDebug(s"Resolved: $resolved Missing in aggs: $missingInAggs")
      if (missingInAggs.nonEmpty) {
        // Add missing grouping exprs and then project them away after the sort.
        Project(a.output,
          Sort(ordering,
            Aggregate(grouping, aggs ++ missingInAggs, child)))
      } else {
        s // Nothing we can do here. Return original plan.
      }
  }
}
168+
123169
/**
124170
* Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]].
125171
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,32 +72,45 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
7272
def childrenResolved: Boolean = !children.exists(!_.resolved)
7373

7474
/**
75-
* Optionally resolves the given string to a [[NamedExpression]]. The attribute is expressed as
75+
* Optionally resolves the given string to a [[NamedExpression]] using the input from all child
76+
* nodes of this LogicalPlan. The attribute is expressed as
7677
* a string in the following form: `[scope].AttributeName.[nested].[fields]...`.
7778
*/
78-
def resolve(name: String): Option[NamedExpression] = {
79+
def resolveChildren(name: String): Option[NamedExpression] =
80+
resolve(name, children.flatMap(_.output))
81+
82+
/**
83+
* Optionally resolves the given string to a [[NamedExpression]] based on the output of this
84+
* LogicalPlan. The attribute is expressed as string in the following form:
85+
* `[scope].AttributeName.[nested].[fields]...`.
86+
*/
87+
def resolve(name: String): Option[NamedExpression] =
88+
resolve(name, output)
89+
90+
/** Performs attribute resolution given a name and a sequence of possible attributes. */
91+
protected def resolve(name: String, input: Seq[Attribute]): Option[NamedExpression] = {
7992
val parts = name.split("\\.")
8093
// Collect all attributes in `input` where either the first part
8194
// matches the name or where the first part matches the scope and the second part matches the
8295
// name. Return these matches along with any remaining parts, which represent dotted access to
8396
// struct fields.
84-
val options = children.flatMap(_.output).flatMap { option =>
97+
val options = input.flatMap { option =>
8598
// If the first part of the desired name matches a qualifier for this possible match, drop it.
8699
val remainingParts =
87100
if (option.qualifiers.contains(parts.head) && parts.size > 1) parts.drop(1) else parts
88101
if (option.name == remainingParts.head) (option, remainingParts.tail.toList) :: Nil else Nil
89102
}
90103

91104
options.distinct match {
92-
case (a, Nil) :: Nil => Some(a) // One match, no nested fields, use it.
105+
case Seq((a, Nil)) => Some(a) // One match, no nested fields, use it.
93106
// One match, but we also need to extract the requested nested field.
94-
case (a, nestedFields) :: Nil =>
107+
case Seq((a, nestedFields)) =>
95108
a.dataType match {
96109
case StructType(fields) =>
97110
Some(Alias(nestedFields.foldLeft(a: Expression)(GetField), nestedFields.last)())
98111
case _ => None // Don't know how to resolve these field references
99112
}
100-
case Nil => None // No matches.
113+
case Seq() => None // No matches.
101114
case ambiguousReferences =>
102115
throw new TreeNodeException(
103116
this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.hive.execution
19+
20+
import scala.reflect.ClassTag
21+
22+
import org.apache.spark.sql.{SQLConf, QueryTest}
23+
import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin}
24+
import org.apache.spark.sql.hive.test.TestHive
25+
import org.apache.spark.sql.hive.test.TestHive._
26+
27+
/**
28+
* A collection of hive query tests where we generate the answers ourselves instead of depending on
29+
* Hive to generate them (in contrast to HiveQuerySuite). Often this is because the query is
30+
* valid, but Hive currently cannot execute it.
31+
*/
32+
class SQLQuerySuite extends QueryTest {
  // ORDER BY may reference a column that the SELECT list does not project.
  test("ordering not in select") {
    val sortedByHiddenColumn = sql("SELECT key FROM src ORDER BY value")
    // Reference answer: sort in a subquery that does project `value`, then drop it.
    val expected =
      sql("SELECT key FROM (SELECT key, value FROM src ORDER BY value) a").collect().toSeq
    checkAnswer(sortedByHiddenColumn, expected)
  }

  // ORDER BY may reference a grouping column that the aggregate does not output.
  test("ordering not in agg") {
    val sortedByHiddenGroupingColumn =
      sql("SELECT key FROM src GROUP BY key, value ORDER BY value")
    // Reference answer: aggregate and sort in a subquery that keeps `value`, then drop it.
    val expected = sql("""
SELECT key
FROM (
SELECT key, value
FROM src
GROUP BY key, value
ORDER BY value) a""").collect().toSeq
    checkAnswer(sortedByHiddenGroupingColumn, expected)
  }
}

0 commit comments

Comments
 (0)