diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 795ef0dbc4c47..93f53eabfbab8 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -84,7 +84,7 @@ def __init__(self, sparkContext, sqlContext=None):
         >>> df.registerTempTable("allTypes")
         >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
         ...            'from allTypes where b and i > 0').collect()
-        [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
+        [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), row.a=1)]
         >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
         ...                   x.row.a, x.list)).collect()
         [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index c363a5efacde8..54ab13ca352d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -385,7 +385,7 @@ class SqlParser extends AbstractSparkSQLParser {
 
   protected lazy val dotExpressionHeader: Parser[Expression] =
     (ident <~ ".") ~ ident ~ rep("." ~> ident) ^^ {
-      case i1 ~ i2 ~ rest => UnresolvedAttribute(i1 + "." + i2 + rest.mkString(".", ".", ""))
+      case i1 ~ i2 ~ rest => UnresolvedAttribute((Seq(i1, i2) ++ rest).mkString("."))
     }
 
   protected lazy val dataType: Parser[DataType] =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e4e542562f22d..3133d0fefcf61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -268,7 +268,9 @@ class Analyzer(catalog: Catalog,
             logDebug(s"Resolving $u to $result")
             result
           case UnresolvedGetField(child, fieldName) if child.resolved =>
-            resolveGetField(child, fieldName)
+            val result = q.resolveGetField(child, fieldName, resolver)
+            logDebug(s"Resolving $fieldName of $child to $result")
+            result
         }
     }
 
@@ -277,36 +279,6 @@ class Analyzer(catalog: Catalog,
      */
     protected def containsStar(exprs: Seq[Expression]): Boolean =
       exprs.exists(_.collect { case _: Star => true }.nonEmpty)
-
-    /**
-     * Returns the resolved `GetField`, and report error if no desired field or over one
-     * desired fields are found.
-     */
-    protected def resolveGetField(expr: Expression, fieldName: String): Expression = {
-      def findField(fields: Array[StructField]): Int = {
-        val checkField = (f: StructField) => resolver(f.name, fieldName)
-        val ordinal = fields.indexWhere(checkField)
-        if (ordinal == -1) {
-          throw new AnalysisException(
-            s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
-        } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
-          throw new AnalysisException(
-            s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
-        } else {
-          ordinal
-        }
-      }
-      expr.dataType match {
-        case StructType(fields) =>
-          val ordinal = findField(fields)
-          StructGetField(expr, fields(ordinal), ordinal)
-        case ArrayType(StructType(fields), containsNull) =>
-          val ordinal = findField(fields)
-          ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
-        case otherType =>
-          throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
-      }
-    }
   }
 
   /**
@@ -320,8 +292,7 @@ class Analyzer(catalog: Catalog,
       case s @ Sort(ordering, global, p @ Project(projectList, child))
           if !s.resolved && p.resolved =>
         val unresolved = ordering.flatMap(_.collect { case UnresolvedAttribute(name) => name })
-        val resolved = unresolved.flatMap(child.resolve(_, resolver))
-        val requiredAttributes = AttributeSet(resolved.collect { case a: Attribute => a })
+        val requiredAttributes = AttributeSet(unresolved.flatMap(child.resolve(_, resolver)))
 
         val missingInProject = requiredAttributes -- p.output
         if (missingInProject.nonEmpty) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 8c4f09b58a4f2..aea5901af7c22 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, Resolver}
+import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.types.{ArrayType, StructType, StructField}
 
 abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
 
@@ -192,14 +193,17 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
 
       // One match, but we also need to extract the requested nested field.
       case Seq((a, nestedFields)) =>
-        // The foldLeft adds UnresolvedGetField for every remaining parts of the name,
-        // and aliased it with the last part of the name.
-        // For example, consider name "a.b.c", where "a" is resolved to an existing attribute.
-        // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias
-        // the final expression as "c".
-        val fieldExprs = nestedFields.foldLeft(a: Expression)(UnresolvedGetField)
-        val aliasName = nestedFields.last
-        Some(Alias(fieldExprs, aliasName)())
+        // The foldLeft resolves each nested field in turn, building the extraction expression.
+        val fieldExprs = nestedFields.foldLeft(a: Expression) { case (e, fieldName) =>
+          resolveGetField(e, fieldName, resolver)
+        }
+
+        // TODO: picking the alias name is tricky; should it be _col1, _col2, ...?
+        // Using the original attribute name, e.g. "a.b.c", still seems confusing,
+        // and this column can never be referenced by that name (it contains "."),
+        // unless the query aliases the output column explicitly, e.g.
+        // SELECT a.b.c AS newCol FROM nestedTable.
+        Some(Alias(fieldExprs, name)())
 
       // No matches.
       case Seq() =>
@@ -212,6 +216,36 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
           s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
     }
   }
+
+  /**
+   * Returns the resolved `GetField`, and reports an error if no matching field or
+   * more than one matching field is found.
+   */
+  def resolveGetField(expr: Expression, fieldName: String, resolver: Resolver): Expression = {
+    def findField(fields: Array[StructField]): Int = {
+      val checkField = (f: StructField) => resolver(f.name, fieldName)
+      val ordinal = fields.indexWhere(checkField)
+      if (ordinal == -1) {
+        throw new AnalysisException(
+          s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
+      } else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
+        throw new AnalysisException(
+          s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+      } else {
+        ordinal
+      }
+    }
+    expr.dataType match {
+      case StructType(fields) =>
+        val ordinal = findField(fields)
+        StructGetField(expr, fields(ordinal), ordinal)
+      case ArrayType(StructType(fields), containsNull) =>
+        val ordinal = findField(fields)
+        ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
+      case otherType =>
+        throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
+    }
+  }
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 097bf0dd23c89..06845e8b109fd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -43,6 +43,22 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
     )
   }
 
+  test("SPARK-6145 ORDER BY on nested fields #1") {
+    sqlCtx.jsonRDD(sqlCtx.sparkContext.parallelize(
+      """{"a": {"b": {"d": 1}}, "c": 1}""" :: Nil)).registerTempTable("nestedOrder")
+
+    checkAnswer(sqlCtx.sql("SELECT 1 FROM nestedOrder ORDER BY c"), Row(1))
+    checkAnswer(sqlCtx.sql("SELECT 1 FROM nestedOrder ORDER BY a.b.d"), Row(1))
+    checkAnswer(sqlCtx.sql("SELECT a.b.d FROM nestedOrder ORDER BY a.b.d"), Row(1))
+  }
+
+  test("SPARK-6145 ORDER BY on nested fields #2") {
+    sqlCtx.jsonRDD(sqlCtx.sparkContext.parallelize(
+      """{"a": {"a": {"a": 1}}, "c": 1}""" :: Nil)).registerTempTable("nestedOrder")
+
+    checkAnswer(sqlCtx.sql("SELECT a.a.a FROM nestedOrder ORDER BY a.a.a"), Row(1))
+  }
+
   test("grouping on nested fields") {
     jsonRDD(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil))
       .registerTempTable("rows")
@@ -52,7 +68,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
         """
           |select attribute, sum(cnt)
          |from (
-          |  select nested.attribute, count(*) as cnt
+          |  select nested.attribute as attribute, count(*) as cnt
           |  from rows
           |  group by nested.attribute) a
          |group by attribute
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
index f4440e5b7846a..ce21d27210710 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
@@ -80,7 +80,7 @@ class HiveResolutionSuite extends HiveComparisonTest {
       .toDF().registerTempTable("caseSensitivityTest")
 
     val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
-    assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"),
+    assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "n.a", "n.b", "n.A", "n.B"),
       "The output schema did not preserve the case of the query.")
     query.collect()
   }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 22ea19bd82f86..c8559f287b6f6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -48,6 +48,15 @@ class SQLQuerySuite extends QueryTest {
       Row(1) :: Row(2) :: Row(3) :: Nil)
   }
 
+  test("SPARK-6145 insert into a table by selecting nested data") {
+    jsonRDD(sparkContext.parallelize(
+      """{"a": {"a": {"a": 1}}, "c": 1}""" :: Nil)).registerTempTable("nestedOrder")
+
+    sql("CREATE TABLE gen_tmp_6145 (key Int)")
+    sql("INSERT INTO TABLE gen_tmp_6145 SELECT a.a.a FROM nestedOrder")
+    sql("DROP TABLE gen_tmp_6145")
+  }
+
   test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") {
     checkAnswer(
       sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"),
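
Note for reviewers (not part of the patch): below is a minimal, self-contained sketch of the resolution scheme this change settles on. None of these names exist in Spark; `NestedResolutionSketch`, `DType`, `StructT`, `getField`, and `resolve` are made-up stand-ins for Catalyst's `LogicalPlan.resolve` / `resolveGetField` pair. The behavioral change it illustrates: a dotted name such as `a.b.d` is now resolved eagerly, field by field, during attribute resolution (with the whole name used as the alias), instead of stacking `UnresolvedGetField` nodes and aliasing with only the last name part.

```scala
// Illustrative sketch only -- these names do not exist in Spark.
object NestedResolutionSketch {

  // Simplified stand-ins for Catalyst's DataType / StructType.
  sealed trait DType
  case object IntT extends DType
  final case class StructT(fields: Seq[(String, DType)]) extends DType

  // Catalyst's Resolver is a (String, String) => Boolean comparison.
  type Resolver = (String, String) => Boolean
  val caseInsensitive: Resolver = (a, b) => a.equalsIgnoreCase(b)

  // One fold step: require exactly one matching field, mirroring findField.
  def getField(dt: DType, fieldName: String, resolver: Resolver): DType = dt match {
    case StructT(fields) =>
      fields.filter { case (n, _) => resolver(n, fieldName) } match {
        case Seq((_, child)) => child
        case Seq() => sys.error(s"No such struct field $fieldName")
        case many => sys.error(s"Ambiguous reference to fields ${many.map(_._1).mkString(", ")}")
      }
    case other => sys.error(s"GetField is not valid on fields of type $other")
  }

  // Resolve a dotted name against a flat output schema: match the head
  // against the output, then fold the tail through the nested struct fields.
  def resolve(name: String, output: Seq[(String, DType)], resolver: Resolver): DType = {
    val parts = name.split("\\.").toSeq
    val root = output.collectFirst { case (n, t) if resolver(n, parts.head) => t }
      .getOrElse(sys.error(s"Cannot resolve ${parts.head}"))
    parts.tail.foldLeft(root)(getField(_, _, resolver))
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the schema in the new test: {"a": {"b": {"d": 1}}, "c": 1}
    val output = Seq("a" -> StructT(Seq("b" -> StructT(Seq("d" -> IntT)))), "c" -> IntT)
    println(resolve("a.b.d", output, caseInsensitive)) // IntT
    println(resolve("A.B.D", output, caseInsensitive)) // IntT, resolver is case-insensitive
  }
}
```

Like the real `resolveGetField`, the sketch fails fast on zero or multiple matches at each step, so repeated names at different nesting levels (as in the `a.a.a` test above) still resolve unambiguously.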