-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-17069] Expose spark.range() as table-valued function in SQL #14656
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5d7fbfb
e5b0e27
0d4ad5e
4d9eb53
21078e6
cc987d3
92e668f
d1fd2b9
fbf515f
29f9538
e9251f5
1f81b97
94ad7f1
a444b9e
7bac3af
78c9f05
4d94ac0
2f80f54
8e03f51
1fa57c4
f8831ca
7ebd563
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.analysis | ||
|
|
||
| import org.apache.spark.{SparkConf, SparkContext} | ||
| import org.apache.spark.sql.catalyst.expressions.Expression | ||
| import org.apache.spark.sql.catalyst.plans._ | ||
| import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} | ||
| import org.apache.spark.sql.catalyst.rules._ | ||
| import org.apache.spark.sql.types.{DataType, IntegerType, LongType} | ||
|
|
||
| /** | ||
| * Rule that resolves table-valued function references. | ||
| */ | ||
| object ResolveTableValuedFunctions extends Rule[LogicalPlan] { | ||
| private lazy val defaultParallelism = | ||
| SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism | ||
|
|
||
| /** | ||
| * List of argument names and their types, used to declare a function. | ||
| */ | ||
| private case class ArgumentList(args: (String, DataType)*) { | ||
| /** | ||
| * Try to cast the expressions to satisfy the expected types of this argument list. If there | ||
| * are any types that cannot be casted, then None is returned. | ||
| */ | ||
| def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = { | ||
| if (args.length == values.length) { | ||
| val casted = values.zip(args).map { case (value, (_, expectedType)) => | ||
| TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType) | ||
| } | ||
| if (casted.forall(_.isDefined)) { | ||
| return Some(casted.map(_.get)) | ||
| } | ||
| } | ||
| None | ||
| } | ||
|
|
||
| override def toString: String = { | ||
| args.map { a => | ||
| s"${a._1}: ${a._2.typeName}" | ||
| }.mkString(", ") | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * A TVF maps argument lists to resolver functions that accept those arguments. Using a map | ||
| * here allows for function overloading. | ||
| */ | ||
| private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan] | ||
|
|
||
| /** | ||
| * TVF builder. | ||
| */ | ||
| private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan]) | ||
| : (ArgumentList, Seq[Any] => LogicalPlan) = { | ||
| (ArgumentList(args: _*), | ||
| pf orElse { | ||
| case args => | ||
| throw new IllegalArgumentException( | ||
| "Invalid arguments for resolved function: " + args.mkString(", ")) | ||
| }) | ||
| } | ||
|
|
||
| /** | ||
| * Internal registry of table-valued functions. | ||
| */ | ||
| private val builtinFunctions: Map[String, TVF] = Map( | ||
| "range" -> Map( | ||
| /* range(end) */ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could make this a bit more concise by using a combination of a builder and partial function. For example: // Builder
def tvf(params: (String, DataType)*)(pf: PartialFunction[Seq[Expression], LogicalPlan]): TVF = (ArgumentList(params: _*), pf)
// Use
private val builtinFunctions: Map[String, TVF] = Map(
"range" -> Map(
/* range(end) */
tvf("end" -> LongType) { case Seq(end: Long) =>
Range(0, end, 1, defaultParallelism)
},
/* range(start, end) */
tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) =>
Range(start, end, 1, defaultParallelism)
}
/* ... */)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That seems nice. Updated. |
||
| tvf("end" -> LongType) { case Seq(end: Long) => | ||
| Range(0, end, 1, defaultParallelism) | ||
| }, | ||
|
|
||
| /* range(start, end) */ | ||
| tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) => | ||
| Range(start, end, 1, defaultParallelism) | ||
| }, | ||
|
|
||
| /* range(start, end, step) */ | ||
| tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) { | ||
| case Seq(start: Long, end: Long, step: Long) => | ||
| Range(start, end, step, defaultParallelism) | ||
| }, | ||
|
|
||
| /* range(start, end, step, numPartitions) */ | ||
| tvf("start" -> LongType, "end" -> LongType, "step" -> LongType, | ||
| "numPartitions" -> IntegerType) { | ||
| case Seq(start: Long, end: Long, step: Long, numPartitions: Int) => | ||
| Range(start, end, step, numPartitions) | ||
| }) | ||
| ) | ||
|
|
||
| override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { | ||
| case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => | ||
| builtinFunctions.get(u.functionName) match { | ||
| case Some(tvf) => | ||
| val resolved = tvf.flatMap { case (argList, resolver) => | ||
| argList.implicitCast(u.functionArgs) match { | ||
| case Some(casted) => | ||
| Some(resolver(casted.map(_.eval()))) | ||
| case _ => | ||
| None | ||
| } | ||
| } | ||
| resolved.headOption.getOrElse { | ||
| val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") | ||
| u.failAnalysis( | ||
| s"""error: table-valued function ${u.functionName} with alternatives: | ||
| |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} | ||
| |cannot be applied to: (${argTypes})""".stripMargin) | ||
| } | ||
| case _ => | ||
| u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -657,6 +657,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { | |
| table.optionalMap(ctx.sample)(withSample) | ||
| } | ||
|
|
||
/**
 * Create a table-valued function call with arguments, e.g. range(1000)
 */
override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
    : LogicalPlan = withOrigin(ctx) {
  // Capture the function name and its (still unresolved) argument expressions;
  // binding the name to an actual function happens later during analysis.
  val functionName = ctx.identifier.getText
  val functionArgs = ctx.expression.asScala.map(expression)
  UnresolvedTableValuedFunction(functionName, functionArgs)
}
|
|
||
| /** | ||
| * Create an inline table (a virtual table in Hive parlance). | ||
| */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| -- unresolved function | ||
| select * from dummy(3); | ||
|
|
||
| -- range call with end | ||
| select * from range(6 + cos(3)); | ||
|
|
||
| -- range call with start and end | ||
| select * from range(5, 10); | ||
|
|
||
| -- range call with step | ||
| select * from range(0, 10, 2); | ||
|
|
||
| -- range call with numPartitions | ||
| select * from range(0, 10, 1, 200); | ||
|
|
||
| -- range call error | ||
| select * from range(1, 1, 1, 1, 1); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you also test nulls?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done. |
||
|
|
||
| -- range call with null | ||
| select * from range(1, null); | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,87 @@ | ||
| -- Automatically generated by SQLQueryTestSuite | ||
| -- Number of queries: 7 | ||
|
|
||
|
|
||
| -- !query 0 | ||
| select * from dummy(3) | ||
| -- !query 0 schema | ||
| struct<> | ||
| -- !query 0 output | ||
| org.apache.spark.sql.AnalysisException | ||
| could not resolve `dummy` to a table-valued function; line 1 pos 14 | ||
|
|
||
|
|
||
| -- !query 1 | ||
| select * from range(6 + cos(3)) | ||
| -- !query 1 schema | ||
| struct<id:bigint> | ||
| -- !query 1 output | ||
| 0 | ||
| 1 | ||
| 2 | ||
| 3 | ||
| 4 | ||
|
|
||
|
|
||
| -- !query 2 | ||
| select * from range(5, 10) | ||
| -- !query 2 schema | ||
| struct<id:bigint> | ||
| -- !query 2 output | ||
| 5 | ||
| 6 | ||
| 7 | ||
| 8 | ||
| 9 | ||
|
|
||
|
|
||
| -- !query 3 | ||
| select * from range(0, 10, 2) | ||
| -- !query 3 schema | ||
| struct<id:bigint> | ||
| -- !query 3 output | ||
| 0 | ||
| 2 | ||
| 4 | ||
| 6 | ||
| 8 | ||
|
|
||
|
|
||
| -- !query 4 | ||
| select * from range(0, 10, 1, 200) | ||
| -- !query 4 schema | ||
| struct<id:bigint> | ||
| -- !query 4 output | ||
| 0 | ||
| 1 | ||
| 2 | ||
| 3 | ||
| 4 | ||
| 5 | ||
| 6 | ||
| 7 | ||
| 8 | ||
| 9 | ||
|
|
||
|
|
||
| -- !query 5 | ||
| select * from range(1, 1, 1, 1, 1) | ||
| -- !query 5 schema | ||
| struct<> | ||
| -- !query 5 output | ||
| org.apache.spark.sql.AnalysisException | ||
| error: table-valued function range with alternatives: | ||
| (end: long) | ||
| (start: long, end: long) | ||
| (start: long, end: long, step: long) | ||
| (start: long, end: long, step: long, numPartitions: integer) | ||
| cannot be applied to: (integer, integer, integer, integer, integer); line 1 pos 14 | ||
|
|
||
|
|
||
| -- !query 6 | ||
| select * from range(1, null) | ||
| -- !query 6 schema | ||
| struct<> | ||
| -- !query 6 output | ||
| java.lang.IllegalArgumentException | ||
| Invalid arguments for resolved function: 1, null |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 for doing this in a separate file