-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-16947][SQL] Support type coercion and foldable expression for inline tables #14676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
d7acae5
6a0450b
2327b79
fcc3caf
f597bae
092605b
4723902
c1071af
1b39a97
2e68438
fb9de34
aed7c5e
08f4e39
285b941
88e7272
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.analysis | ||
|
|
||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.expressions.Cast | ||
| import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} | ||
| import org.apache.spark.sql.catalyst.rules.Rule | ||
| import org.apache.spark.sql.types.{StructField, StructType} | ||
|
|
||
| /** | ||
| * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. | ||
| */ | ||
| object ResolveInlineTables extends Rule[LogicalPlan] { | ||
| override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { | ||
| case table: UnresolvedInlineTable if table.expressionsResolved => | ||
| validateInputDimension(table) | ||
| validateInputFoldable(table) | ||
| convert(table) | ||
| } | ||
|
|
||
| /** | ||
| * Validates that all inline table data are foldable expressions. | ||
| * | ||
| * This is publicly visible for unit testing. | ||
| */ | ||
| def validateInputFoldable(table: UnresolvedInlineTable): Unit = { | ||
| table.rows.foreach { row => | ||
| row.foreach { e => | ||
| if (!e.resolved || !e.foldable) { | ||
| e.failAnalysis(s"cannot evaluate expression ${e.sql} in inline table definition") | ||
|
||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Validates the input data dimension: | ||
| * 1. All rows have the same cardinality. | ||
| * 2. The number of column aliases defined is consistent with the number of columns in data. | ||
| * | ||
| * This is publicly visible for unit testing. | ||
| */ | ||
| def validateInputDimension(table: UnresolvedInlineTable): Unit = { | ||
| if (table.rows.nonEmpty) { | ||
| val numCols = table.rows.head.size | ||
| table.rows.zipWithIndex.foreach { case (row, ri) => | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shall we just get the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a good idea. Let me do that. |
||
| if (row.size != numCols) { | ||
| table.failAnalysis(s"expected $numCols columns but found ${row.size} columns in row $ri") | ||
| } | ||
| } | ||
|
|
||
| if (table.names.size != numCols) { | ||
| table.failAnalysis(s"expected ${table.names.size} columns but found $numCols in first row") | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Convert a valid (with right shape and foldable inputs) [[UnresolvedInlineTable]] | ||
| * into a [[LocalRelation]]. | ||
| * | ||
| * This function attempts to coerce inputs into consistent types. | ||
| * | ||
| * This is publicly visible for unit testing. | ||
| */ | ||
| def convert(table: UnresolvedInlineTable): LocalRelation = { | ||
| val numCols = table.rows.head.size | ||
|
|
||
| // For each column, traverse all the values and find a common data type. | ||
| val targetTypes = Seq.tabulate(numCols) { ci => | ||
|
||
| val inputTypes = table.rows.map(_(ci).dataType) | ||
| TypeCoercion.findWiderTypeWithoutStringPromotion(inputTypes).getOrElse { | ||
|
||
| table.failAnalysis(s"incompatible types found in column $ci for inline table") | ||
| } | ||
| } | ||
| assert(targetTypes.size == table.names.size) | ||
|
||
|
|
||
| val newRows: Seq[InternalRow] = table.rows.map { row => | ||
| InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) => | ||
| val targetType = targetTypes(ci) | ||
| if (e.dataType.sameType(targetType)) { | ||
| e.eval() | ||
|
||
| } else { | ||
| Cast(e, targetType).eval() | ||
| } | ||
| }) | ||
| } | ||
|
|
||
| val attributes = StructType(targetTypes.zip(table.names) | ||
| .map { case (typ, name) => StructField(name, typ) }).toAttributes | ||
|
||
| LocalRelation(attributes, newRows) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -49,6 +49,23 @@ case class UnresolvedRelation( | |
| override lazy val resolved = false | ||
| } | ||
|
|
||
| /** | ||
| * An inline table that has not been resolved yet. Once resolved, it is turned by the analyzer into | ||
| * a [[org.apache.spark.sql.catalyst.plans.logical.LocalRelation]]. | ||
| * | ||
| * @param names list of column names | ||
| * @param rows expressions for the data | ||
| */ | ||
| case class UnresolvedInlineTable( | ||
| names: Seq[String], | ||
| rows: Seq[Seq[Expression]]) | ||
| extends LeafNode { | ||
|
|
||
| lazy val expressionsResolved: Boolean = rows.forall(_.forall(_.resolved)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is used only once. Lets move this code into that location.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do want this memoized, so a lazy val is better here. |
||
| override def output: Seq[Attribute] = Nil | ||
| override lazy val resolved = false | ||
| } | ||
|
|
||
| /** | ||
| * Holds the name of an attribute that has yet to be resolved. | ||
| */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.analysis | ||
|
|
||
| import org.scalatest.BeforeAndAfter | ||
|
|
||
| import org.apache.spark.sql.AnalysisException | ||
| import org.apache.spark.sql.catalyst.expressions.{Literal, Rand} | ||
| import org.apache.spark.sql.catalyst.plans.PlanTest | ||
| import org.apache.spark.sql.types.LongType | ||
|
|
||
| /** | ||
| * Unit tests for [[ResolveInlineTables]]. Note that there are also test cases defined in | ||
| * end-to-end tests (in sql/core module) for verifying the correct error messages are shown | ||
| * in negative cases. | ||
| */ | ||
| class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter { | ||
|
|
||
| private def lit(v: Any): Literal = Literal(v) | ||
|
|
||
| test("validate inputs are foldable") { | ||
| ResolveInlineTables.validateInputFoldable( | ||
| UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))) | ||
|
|
||
| // nondeterministic (rand) | ||
| intercept[AnalysisException] { | ||
| ResolveInlineTables.validateInputFoldable( | ||
| UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(Rand(1))))) | ||
| } | ||
|
|
||
| // unresolved attribute | ||
| intercept[AnalysisException] { | ||
| ResolveInlineTables.validateInputFoldable( | ||
| UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))) | ||
|
||
| } | ||
| } | ||
|
|
||
| test("validate input dimensions") { | ||
| ResolveInlineTables.validateInputDimension( | ||
| UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))) | ||
|
|
||
| // num alias != data dimension | ||
| intercept[AnalysisException] { | ||
| ResolveInlineTables.validateInputDimension( | ||
| UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))) | ||
| } | ||
|
|
||
| // num alias == data dimension, but data themselves are inconsistent | ||
| intercept[AnalysisException] { | ||
| ResolveInlineTables.validateInputDimension( | ||
| UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))) | ||
| } | ||
| } | ||
|
|
||
| test("do not fire the rule if not all expressions are resolved") { | ||
| val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A")))) | ||
| assert(ResolveInlineTables(table) == table) | ||
| } | ||
|
|
||
| test("convert") { | ||
| val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) | ||
| val converted = ResolveInlineTables.convert(table) | ||
|
|
||
| assert(converted.output.map(_.dataType) == Seq(LongType)) | ||
| assert(converted.data.size == 2) | ||
| assert(converted.data(0).getLong(0) == 1L) | ||
| assert(converted.data(1).getLong(0) == 2L) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
|
|
||
| -- single row, without table and column alias | ||
| select * from values ("one", 1); | ||
|
|
||
| -- single row, without column alias | ||
| select * from values ("one", 1) as data; | ||
|
|
||
| -- single row | ||
| select * from values ("one", 1) as data(a, b); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a case for
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
||
|
|
||
| -- two rows | ||
|
||
| select * from values ("one", 1), ("two", 2) as data(a, b); | ||
|
|
||
| -- int and long coercion | ||
| select * from values ("one", 1), ("two", 2L) as data(a, b); | ||
|
|
||
| -- foldable expressions | ||
| select * from values ("one", 1 + 0), ("two", 1 + 3L) as data(a, b); | ||
|
|
||
| -- complex types | ||
| select * from values ("one", array(0, 1)), ("two", array(2, 3)) as data(a, b); | ||
|
|
||
| -- decimal and double coercion | ||
| select * from values ("one", 2.0), ("two", 3.0D) as data(a, b); | ||
|
|
||
| -- error reporting: different number of columns | ||
| select * from values ("one", 2.0), ("two") as data(a, b); | ||
|
|
||
| -- error reporting: types that are incompatible | ||
| select * from values ("one", array(0, 1)), ("two", struct(1, 2)) as data(a, b); | ||
|
|
||
| -- error reporting: number aliases different from number data values | ||
| select * from values ("one"), ("two") as data(a, b); | ||
|
|
||
| -- error reporting: unresolved expression | ||
| select * from values ("one", random_not_exist_func(1)), ("two", 2) as data(a, b); | ||
|
|
||
| -- error reporting: aggregate expression | ||
| select * from values ("one", count(1)), ("two", 2) as data(a, b); | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test is located in the same package, so you could reduce visibility to protected/package.