From 5d7fbfb30163b15cbbe832c755a4774c7c9e3b91 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Sun, 14 Aug 2016 15:38:40 -0700 Subject: [PATCH 01/21] Sun Aug 14 15:38:40 PDT 2016 --- .../spark/sql/catalyst/parser/SqlBase.g4 | 1 + .../sql/catalyst/analysis/Analyzer.scala | 1 + .../analysis/TableValuedFunctions.scala | 52 +++++++++++++++++++ .../sql/catalyst/analysis/unresolved.scala | 11 ++++ .../sql/catalyst/parser/AstBuilder.scala | 13 +++++ 5 files changed, 78 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 5e1046293a20..80c8453ccca1 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -432,6 +432,7 @@ relationPrimary | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation | inlineTable #inlineTableDefault2 + | identifier '(' expression (',' expression)* ')' #tableValuedFunction ; inlineTable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2efa997ff22d..2e9821e8f918 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -86,6 +86,7 @@ class Analyzer( WindowsSubstitution, EliminateUnions), Batch("Resolution", fixedPoint, + ResolveTableValuedFunctions :: ResolveRelations :: ResolveReferences :: ResolveDeserializer :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala new file mode 100644 index 000000000000..1d01e84173bf --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} +import org.apache.spark.sql.catalyst.rules._ + +/** + * Rule for resolving references to table-valued functions. Currently this only resolves + * references to the hard-coded range() operator. + */ +object ResolveTableValuedFunctions extends Rule[LogicalPlan] { + private def defaultParallelism: Int = 200 // TODO(ekl) fix + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case u: UnresolvedTableValuedFunction => + // TODO(ekl) we should have a tvf registry + if (u.functionName != "range") { + u.failAnalysis(s"could not resolve `${u.functionName}` to a table valued function") + } + val evaluatedArgs = u.functionArgs.map(_.eval()) + val longArgs = evaluatedArgs.map(_.toString.toLong) // TODO(ekl) fix + longArgs match { + case Seq(end) => + Range(0, end, 1, defaultParallelism) + case Seq(start, end) => + Range(start, end, 1, defaultParallelism) + case Seq(start, end, step) => + Range(start, end, step, defaultParallelism) + case Seq(start, end, step, numPartitions) => + Range(start, end, step, numPartitions.toInt) + case _ => + u.failAnalysis(s"invalid number of argument for range(): ${longArgs}") + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 609089a302c8..575eb45c7982 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -49,6 +49,17 @@ case class UnresolvedRelation( override lazy val resolved = false } +/** + * Holds a table-valued-function call that has yet to be resolved. + */ +case class UnresolvedTableValuedFunction( + functionName: String, functionArgs: Seq[Expression]) extends LeafNode { + + override def output: Seq[Attribute] = Nil + + override lazy val resolved = false +} + /** * Holds the name of an attribute that has yet to be resolved. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index f2cc8d362478..7695b8f3381c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -652,6 +652,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { table.optionalMap(ctx.sample)(withSample) } + /** + * Create a table-valued-function call with arguments, e.g. range(1000) + */ + override def visitTableValuedFunction(ctx: TableValuedFunctionContext) + : LogicalPlan = withOrigin(ctx) { + val expressions = ctx.expression.asScala.map { ec => + val e = expression(ec) + assert(e.foldable, "All params of a table-valued-function call must be constants.", ec) + e + } + UnresolvedTableValuedFunction(ctx.identifier.getText, expressions) + } + /** * Create an inline table (a virtual table in Hive parlance). */ From e5b0e2747e7e0bf254f173fbc0752e2e47266918 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 00:43:07 -0700 Subject: [PATCH 02/21] Mon Aug 15 00:43:07 PDT 2016 --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../analysis/TableValuedFunctions.scala | 103 ++++++++++++++---- .../sql/catalyst/analysis/unresolved.scala | 2 +- .../sql/catalyst/parser/AstBuilder.scala | 4 +- 4 files changed, 87 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 80c8453ccca1..da5daa130710 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -432,7 +432,7 @@ relationPrimary | '(' queryNoWith ')' sample? (AS? strictIdentifier)? #aliasedQuery | '(' relation ')' sample? (AS? strictIdentifier)? #aliasedRelation | inlineTable #inlineTableDefault2 - | identifier '(' expression (',' expression)* ')' #tableValuedFunction + | identifier '(' (expression (',' expression)*)? ')' #tableValuedFunction ; inlineTable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala index 1d01e84173bf..ae511468b8e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala @@ -17,36 +17,99 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules._ /** - * Rule for resolving references to table-valued functions. Currently this only resolves - * references to the hard-coded range() operator. + * Rule that resolves table-valued function references. */ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { - private def defaultParallelism: Int = 200 // TODO(ekl) fix + private lazy val defaultParallelism = + SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case u: UnresolvedTableValuedFunction => - // TODO(ekl) we should have a tvf registry - if (u.functionName != "range") { - u.failAnalysis(s"could not resolve `${u.functionName}` to a table valued function") + /** + * Type aliases for a TVF declaration. A TVF maps a sequence of named arguments to a function + * resolving the TVF given a matching list of arguments from the user. This allows for + * function overloading (e.g. range(100), range(0, 100)). + */ + private type NamedArguments = Seq[Tuple2[String, Class[_]]] + private type TVF = Map[NamedArguments, Seq[Any] => LogicalPlan] + + /** + * Internal registry of table-valued-functions. TODO(ekl) we should have a proper registry + */ + private val builtinFunctions: Map[String, TVF] = Map( + "range" -> Map( + /* range(end) */ + Seq(("end", classOf[Number])) -> ( + (args: Seq[Any]) => + Range(0, args(0).asInstanceOf[Number].longValue, 1, defaultParallelism)), + + /* range(start, end) */ + Seq(("start", classOf[Number]), ("end", classOf[Number])) -> ( + (args: Seq[Any]) => + Range( + args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, 1, + defaultParallelism)), + + /* range(start, end, step) */ + Seq(("start", classOf[Number]), ("end", classOf[Number]), ("steps", classOf[Number])) -> ( + (args: Seq[Any]) => + Range( + args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, + args(2).asInstanceOf[Number].longValue, defaultParallelism)), + + /* range(start, end, step, numPartitions) */ + Seq(("start", classOf[Number]), ("end", classOf[Number]), ("steps", classOf[Number]), + ("numPartitions", classOf[Integer])) -> ( + (args: Seq[Any]) => + Range( + args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, + args(2).asInstanceOf[Number].longValue, args(3).asInstanceOf[Integer])) + ) + ) + + /** + * Returns whether a given sequence of values can be assigned to the specified arguments. + */ + private def assignableFrom(args: NamedArguments, values: Seq[Any]): Boolean = { + if (args.length == values.length) { + args.zip(values).forall { case ((name, clazz), value) => + clazz.isAssignableFrom(value.getClass) } - val evaluatedArgs = u.functionArgs.map(_.eval()) - val longArgs = evaluatedArgs.map(_.toString.toLong) // TODO(ekl) fix - longArgs match { - case Seq(end) => - Range(0, end, 1, defaultParallelism) - case Seq(start, end) => - Range(start, end, 1, defaultParallelism) - case Seq(start, end, step) => - Range(start, end, step, defaultParallelism) - case Seq(start, end, step, numPartitions) => - Range(start, end, step, numPartitions.toInt) + } else { + false + } + } + + /** + * Formats a list of named args, e.g. to "start: Number, end: Number, steps: Number". + */ + private def formatArgs(args: NamedArguments): String = { + args.map { a => + s"${a._1}: ${a._2.getSimpleName}" + }.mkString(", ") + } + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case u: UnresolvedTableValuedFunction => + builtinFunctions.get(u.functionName) match { + case Some(tvf) => + val evaluatedArgs = u.functionArgs.map(_.eval()) + for ((argSpec, resolver) <- tvf) { + if (assignableFrom(argSpec, evaluatedArgs)) { + return resolver(evaluatedArgs) + } + } + val argTypes = evaluatedArgs.map(_.getClass.getSimpleName).mkString(", ") + u.failAnalysis( + s"""error: table-valued function ${u.functionName} with alternatives: + |${tvf.keys.map(formatArgs).toSeq.sorted.map(x => s" ($x)").mkString("\n")} + |cannot be applied to: (${argTypes})""".stripMargin) case _ => - u.failAnalysis(s"invalid number of argument for range(): ${longArgs}") + u.failAnalysis(s"could not resolve `${u.functionName}` to a table valued function") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 575eb45c7982..a66a551bb21f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -50,7 +50,7 @@ case class UnresolvedRelation( } /** - * Holds a table-valued-function call that has yet to be resolved. + * Holds a table-valued function call that has yet to be resolved. */ case class UnresolvedTableValuedFunction( functionName: String, functionArgs: Seq[Expression]) extends LeafNode { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7695b8f3381c..087214f2efdc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -653,13 +653,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a table-valued-function call with arguments, e.g. range(1000) + * Create a table-valued function call with arguments, e.g. range(1000) */ override def visitTableValuedFunction(ctx: TableValuedFunctionContext) : LogicalPlan = withOrigin(ctx) { val expressions = ctx.expression.asScala.map { ec => val e = expression(ec) - assert(e.foldable, "All params of a table-valued-function call must be constants.", ec) + assert(e.foldable, "All arguments of a table-valued-function must be constants.", ec) e } UnresolvedTableValuedFunction(ctx.identifier.getText, expressions) From 0d4ad5e4ade5ef4e93f5f6291b94dcea3885437f Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 00:43:50 -0700 Subject: [PATCH 03/21] Mon Aug 15 00:43:50 PDT 2016 --- .../spark/sql/catalyst/analysis/TableValuedFunctions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala index ae511468b8e1..b7a50fec5b3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala @@ -31,7 +31,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { /** * Type aliases for a TVF declaration. A TVF maps a sequence of named arguments to a function - * resolving the TVF given a matching list of arguments from the user. This allows for + * resolving the TVF given a matching sequence of values from the user. This allows for * function overloading (e.g. range(100), range(0, 100)). */ private type NamedArguments = Seq[Tuple2[String, Class[_]]] From 4d9eb53d50a59da65ebd5a4480762cf6c1c9d39b Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 00:44:07 -0700 Subject: [PATCH 04/21] Mon Aug 15 00:44:07 PDT 2016 --- .../spark/sql/catalyst/analysis/TableValuedFunctions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala index b7a50fec5b3c..9ab78774534b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala @@ -31,7 +31,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { /** * Type aliases for a TVF declaration. A TVF maps a sequence of named arguments to a function - * resolving the TVF given a matching sequence of values from the user. This allows for + * resolving the TVF given a matching sequence of values from the user. Using a map allows for * function overloading (e.g. range(100), range(0, 100)). */ private type NamedArguments = Seq[Tuple2[String, Class[_]]] From 21078e6633cc8e03d374fd5504c9d70bfa4eebc5 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 00:44:26 -0700 Subject: [PATCH 05/21] Mon Aug 15 00:44:26 PDT 2016 --- .../spark/sql/catalyst/analysis/TableValuedFunctions.scala | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala index 9ab78774534b..276f2ba1a9ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala @@ -67,9 +67,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { (args: Seq[Any]) => Range( args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, - args(2).asInstanceOf[Number].longValue, args(3).asInstanceOf[Integer])) - ) - ) + args(2).asInstanceOf[Number].longValue, args(3).asInstanceOf[Integer])))) /** * Returns whether a given sequence of values can be assigned to the specified arguments. From cc987d38ba06fcdaf001cd59804ec46c7a90bada Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 01:08:40 -0700 Subject: [PATCH 06/21] Mon Aug 15 01:08:40 PDT 2016 --- .../analysis/TableValuedFunctions.scala | 74 ++++++++++--------- 1 file changed, 38 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala index 276f2ba1a9ab..7584122223fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala @@ -30,12 +30,34 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism /** - * Type aliases for a TVF declaration. A TVF maps a sequence of named arguments to a function - * resolving the TVF given a matching sequence of values from the user. Using a map allows for - * function overloading (e.g. range(100), range(0, 100)). + * List of argument names and their types, used to declare a function call. */ - private type NamedArguments = Seq[Tuple2[String, Class[_]]] - private type TVF = Map[NamedArguments, Seq[Any] => LogicalPlan] + private case class ArgumentList(args: (String, Class[_])*) { + /** + * @return whether this list is assignable from the given sequence of values. + */ + def assignableFrom(values: Seq[Any]): Boolean = { + if (args.length == values.length) { + args.zip(values).forall { case ((name, clazz), value) => + clazz.isAssignableFrom(value.getClass) + } + } else { + false + } + } + + override def toString: String = { + args.map { a => + s"${a._1}: ${a._2.getSimpleName}" + }.mkString(", ") + } + } + + /** + * A TVF maps argument lists to a resolving functions that accept those arguments. Using a map + * here allows for function overloading. + */ + private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan] /** * Internal registry of table-valued-functions. TODO(ekl) we should have a proper registry @@ -43,68 +65,48 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { private val builtinFunctions: Map[String, TVF] = Map( "range" -> Map( /* range(end) */ - Seq(("end", classOf[Number])) -> ( + ArgumentList(("end", classOf[Number])) -> ( (args: Seq[Any]) => Range(0, args(0).asInstanceOf[Number].longValue, 1, defaultParallelism)), /* range(start, end) */ - Seq(("start", classOf[Number]), ("end", classOf[Number])) -> ( + ArgumentList(("start", classOf[Number]), ("end", classOf[Number])) -> ( (args: Seq[Any]) => Range( args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, 1, defaultParallelism)), /* range(start, end, step) */ - Seq(("start", classOf[Number]), ("end", classOf[Number]), ("steps", classOf[Number])) -> ( + ArgumentList(("start", classOf[Number]), ("end", classOf[Number]), + ("steps", classOf[Number])) -> ( (args: Seq[Any]) => Range( args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, args(2).asInstanceOf[Number].longValue, defaultParallelism)), /* range(start, end, step, numPartitions) */ - Seq(("start", classOf[Number]), ("end", classOf[Number]), ("steps", classOf[Number]), - ("numPartitions", classOf[Integer])) -> ( + ArgumentList(("start", classOf[Number]), ("end", classOf[Number]), + ("steps", classOf[Number]), ("numPartitions", classOf[Integer])) -> ( (args: Seq[Any]) => Range( args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, - args(2).asInstanceOf[Number].longValue, args(3).asInstanceOf[Integer])))) - - /** - * Returns whether a given sequence of values can be assigned to the specified arguments. - */ - private def assignableFrom(args: NamedArguments, values: Seq[Any]): Boolean = { - if (args.length == values.length) { - args.zip(values).forall { case ((name, clazz), value) => - clazz.isAssignableFrom(value.getClass) - } - } else { - false - } - } - - /** - * Formats a list of named args, e.g. to "start: Number, end: Number, steps: Number". - */ - private def formatArgs(args: NamedArguments): String = { - args.map { a => - s"${a._1}: ${a._2.getSimpleName}" - }.mkString(", ") - } + args(2).asInstanceOf[Number].longValue, args(3).asInstanceOf[Integer]))) + ) override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction => builtinFunctions.get(u.functionName) match { case Some(tvf) => val evaluatedArgs = u.functionArgs.map(_.eval()) - for ((argSpec, resolver) <- tvf) { - if (assignableFrom(argSpec, evaluatedArgs)) { + for ((argList, resolver) <- tvf) { + if (argList.assignableFrom(evaluatedArgs)) { return resolver(evaluatedArgs) } } val argTypes = evaluatedArgs.map(_.getClass.getSimpleName).mkString(", ") u.failAnalysis( s"""error: table-valued function ${u.functionName} with alternatives: - |${tvf.keys.map(formatArgs).toSeq.sorted.map(x => s" ($x)").mkString("\n")} + |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} |cannot be applied to: (${argTypes})""".stripMargin) case _ => u.failAnalysis(s"could not resolve `${u.functionName}` to a table valued function") From 92e668f94a57440beb6f9a40616e57d098794ab0 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 01:21:44 -0700 Subject: [PATCH 07/21] Mon Aug 15 01:21:44 PDT 2016 --- .../spark/sql/TableValuedFunctionSuite.scala | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TableValuedFunctionSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TableValuedFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TableValuedFunctionSuite.scala new file mode 100644 index 000000000000..6ea39fc313c3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TableValuedFunctionSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext + +class TableValuedFunctionSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("built-in range") { + checkAnswer( + sql("select * from range(3)"), + Row(0) :: Row(1) :: Row(2) :: Nil) + + checkAnswer( + sql("select count(*) from range(100)"), + Row(100)) + + checkAnswer( + sql("select count(*) from range(10, 100)"), + Row(90)) + + checkAnswer( + sql("select count(*) from range(10, 100, 2)"), + Row(45)) + + assert(sql("select * from range(10, 100, 2, 12)").rdd.getNumPartitions == 12) + + val error = intercept[AnalysisException] { + sql("select count(*) from range(2, 'x')") + } + assert( + error.getMessage.contains("""error: table-valued function range with alternatives: + | (end: Number) + | (start: Number, end: Number) + | (start: Number, end: Number, steps: Number) + | (start: Number, end: Number, steps: Number, numPartitions: Integer) + |cannot be applied to: (Integer, UTF8String)""".stripMargin)) + } +} From d1fd2b9eb24a74dc2bab61f2cb05940bb1f91b11 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 01:22:29 -0700 Subject: [PATCH 08/21] Mon Aug 15 01:22:29 PDT 2016 --- .../spark/sql/catalyst/analysis/TableValuedFunctions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala index 7584122223fe..99caf07c1128 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala @@ -30,7 +30,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { SparkContext.getOrCreate(new SparkConf(false)).defaultParallelism /** - * List of argument names and their types, used to declare a function call. + * List of argument names and their types, used to declare a function. */ private case class ArgumentList(args: (String, Class[_])*) { /** From fbf515f80c7fc4f87c4992afb55a0ea3d6f6e3dc Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 01:22:58 -0700 Subject: [PATCH 09/21] Mon Aug 15 01:22:58 PDT 2016 --- .../spark/sql/catalyst/analysis/TableValuedFunctions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala index 99caf07c1128..b273c893cca6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala @@ -54,7 +54,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { } /** - * A TVF maps argument lists to a resolving functions that accept those arguments. Using a map + * A TVF maps argument lists to resolving functions that accept those arguments. Using a map * here allows for function overloading. */ private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan] From 29f953853405d3c69c30951972c72e4f8dcd8584 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 01:31:47 -0700 Subject: [PATCH 10/21] Mon Aug 15 01:31:47 PDT 2016 --- .../sql/catalyst/analysis/TableValuedFunctions.scala | 4 ++-- .../org/apache/spark/sql/TableValuedFunctionSuite.scala | 8 +++++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala index b273c893cca6..904ed3f202f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala @@ -60,7 +60,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan] /** - * Internal registry of table-valued-functions. TODO(ekl) we should have a proper registry + * Internal registry of table-valued functions. TODO(ekl) we should have a proper registry */ private val builtinFunctions: Map[String, TVF] = Map( "range" -> Map( @@ -109,7 +109,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} |cannot be applied to: (${argTypes})""".stripMargin) case _ => - u.failAnalysis(s"could not resolve `${u.functionName}` to a table valued function") + u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TableValuedFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TableValuedFunctionSuite.scala index 6ea39fc313c3..1ba9a36f981b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TableValuedFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TableValuedFunctionSuite.scala @@ -20,7 +20,13 @@ package org.apache.spark.sql import org.apache.spark.sql.test.SharedSQLContext class TableValuedFunctionSuite extends QueryTest with SharedSQLContext { - import testImplicits._ + + test("unresolvable function") { + val error = intercept[AnalysisException] { + sql("select * from dummy(3)") + } + assert(error.getMessage.contains("could not resolve `dummy` to a table-valued function")) + } test("built-in range") { checkAnswer( From e9251f56b7b72f51b45e99762ac7ac5f5492953e Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 01:50:02 -0700 Subject: [PATCH 11/21] Mon Aug 15 01:50:02 PDT 2016 --- .../spark/sql/catalyst/analysis/TableValuedFunctions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala index 904ed3f202f7..f3a9bcfa1b62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala @@ -78,7 +78,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { /* range(start, end, step) */ ArgumentList(("start", classOf[Number]), ("end", classOf[Number]), - ("steps", classOf[Number])) -> ( + ("step", classOf[Number])) -> ( (args: Seq[Any]) => Range( args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, @@ -86,7 +86,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { /* range(start, end, step, numPartitions) */ ArgumentList(("start", classOf[Number]), ("end", classOf[Number]), - ("steps", classOf[Number]), ("numPartitions", classOf[Integer])) -> ( + ("step", classOf[Number]), ("numPartitions", classOf[Integer])) -> ( (args: Seq[Any]) => Range( args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, From 1f81b97b81f5ee859849e50021df094e03f2704c Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 01:52:04 -0700 Subject: [PATCH 12/21] Mon Aug 15 01:52:04 PDT 2016 --- .../spark/sql/catalyst/analysis/TableValuedFunctions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala index f3a9bcfa1b62..556a815c74be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala @@ -54,7 +54,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { } /** - * A TVF maps argument lists to resolving functions that accept those arguments. Using a map + * A TVF maps argument lists to resolver functions that accept those arguments. Using a map * here allows for function overloading. */ private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan] From a444b9efa1d4f3fc00e51076ba8195d2571cf2c0 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 16:31:02 -0700 Subject: [PATCH 13/21] Mon Aug 15 16:31:02 PDT 2016 --- ...bleValuedFunctions.scala => ResolveTableValuedFunctions.scala} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/{TableValuedFunctions.scala => ResolveTableValuedFunctions.scala} (100%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableValuedFunctions.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala From 7bac3af6748e63148368f53f088d765ceef69ac1 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 16:43:02 -0700 Subject: [PATCH 14/21] Mon Aug 15 16:43:02 PDT 2016 --- .../inputs/table-valued-functions.sql | 17 ++++ .../results/table-valued-functions.sql.out | 78 +++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql new file mode 100644 index 000000000000..7893ec2dd05a --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -0,0 +1,17 @@ +-- unresolved function +select * from dummy(3); + +-- range call with end +select * from range(5); + +-- range call with start and end +select * from range(5, 10); + +-- range call with step +select * from range(0, 10, 2); + +-- range call with numPartitions +select * from range(0, 10, 1, 200); + +-- range call error +select * from range(2, 'x'); diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out new file mode 100644 index 000000000000..c77f77b0fbf5 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -0,0 +1,78 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +select * from dummy(3) +-- !query 0 schema +struct<> +-- !query 0 output +org.apache.spark.sql.AnalysisException +could not resolve `dummy` to a table-valued function; line 1 pos 14 + + +-- !query 1 +select * from range(5) +-- !query 1 schema +struct +-- !query 1 output +0 +1 +2 +3 +4 + + +-- !query 2 +select * from range(5, 10) +-- !query 2 schema +struct +-- !query 2 output +5 +6 +7 +8 +9 + + +-- !query 3 +select * from range(0, 10, 2) +-- !query 3 schema +struct +-- !query 3 output +0 +2 +4 +6 +8 + + +-- !query 4 +select * from range(0, 10, 1, 200) +-- !query 4 schema +struct +-- !query 4 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 5 +select * from range(2, 'x') +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +error: table-valued function range with alternatives: + (end: Number) + (start: Number, end: Number) + (start: Number, end: Number, step: Number) + (start: Number, end: Number, step: Number, numPartitions: Integer) +cannot be applied to: (Integer, UTF8String); line 1 pos 14 From 78c9f05b6e3babdd4f935b3e39c64304de90fde9 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 16:43:27 -0700 Subject: [PATCH 15/21] Mon Aug 15 16:43:27 PDT 2016 --- .../spark/sql/TableValuedFunctionSuite.scala | 61 ------------------- 1 file changed, 61 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TableValuedFunctionSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TableValuedFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TableValuedFunctionSuite.scala deleted file mode 100644 index 1ba9a36f981b..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/TableValuedFunctionSuite.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.test.SharedSQLContext - -class TableValuedFunctionSuite extends QueryTest with SharedSQLContext { - - test("unresolvable function") { - val error = intercept[AnalysisException] { - sql("select * from dummy(3)") - } - assert(error.getMessage.contains("could not resolve `dummy` to a table-valued function")) - } - - test("built-in range") { - checkAnswer( - sql("select * from range(3)"), - Row(0) :: Row(1) :: Row(2) :: Nil) - - checkAnswer( - sql("select count(*) from range(100)"), - Row(100)) - - checkAnswer( - sql("select count(*) from range(10, 100)"), - Row(90)) - - checkAnswer( - sql("select count(*) from range(10, 100, 2)"), - Row(45)) - - assert(sql("select * from range(10, 100, 2, 12)").rdd.getNumPartitions == 12) - - val error = intercept[AnalysisException] { - sql("select count(*) from range(2, 'x')") - } - assert( - error.getMessage.contains("""error: table-valued function range with alternatives: - | (end: Number) - | (start: Number, end: Number) - | (start: Number, end: Number, steps: Number) - | (start: Number, end: Number, steps: Number, numPartitions: Integer) - |cannot be applied to: (Integer, UTF8String)""".stripMargin)) - } -} From 4d94ac034d5138fb9015b134c39aff5647ab21ba Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 17:47:52 -0700 Subject: [PATCH 16/21] Mon Aug 15 17:47:52 PDT 2016 --- .../spark/sql/catalyst/parser/PlanParserSuite.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 34d52c75e0af..730bac5ba176 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator +import org.apache.spark.sql.catalyst.analysis.{UnresolvedGenerator, UnresolvedTableValuedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -426,6 +426,12 @@ class PlanParserSuite extends PlanTest { assertEqual("table d.t", table("d", "t")) } + test("table valued function") { + assertEqual( + "select * from range(2)", + UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star())) + } + test("inline table") { assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( Seq('col1.int), From 2f80f549dd3d765fd3fdc63f9795d8e5562e38fa Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 15 Aug 2016 23:06:33 -0700 Subject: [PATCH 17/21] Mon Aug 15 23:06:33 PDT 2016 --- .../ResolveTableValuedFunctions.scala | 53 ++++++++++--------- .../inputs/table-valued-functions.sql | 2 +- .../results/table-valued-functions.sql.out | 12 ++--- 3 files changed, 35 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index 556a815c74be..87e69730ef44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.types.{DataType, IntegerType, LongType} /** * Rule that resolves table-valued function references. @@ -32,23 +34,26 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { /** * List of argument names and their types, used to declare a function. */ - private case class ArgumentList(args: (String, Class[_])*) { + private case class ArgumentList(args: (String, DataType)*) { /** - * @return whether this list is assignable from the given sequence of values. + * Try to cast the expressions to satisfy the expected types of this argument list. If there + * are any types that cannot be casted, then None is returned. */ - def assignableFrom(values: Seq[Any]): Boolean = { + def implicitCast(values: Seq[Expression]): Option[Seq[Expression]] = { if (args.length == values.length) { - args.zip(values).forall { case ((name, clazz), value) => - clazz.isAssignableFrom(value.getClass) + val casted = values.zip(args).map { case (value, (_, expectedType)) => + TypeCoercion.ImplicitTypeCasts.implicitCast(value, expectedType) + } + if (casted.forall(_.isDefined)) { + return Some(casted.map(_.get)) } - } else { - false } + None } override def toString: String = { args.map { a => - s"${a._1}: ${a._2.getSimpleName}" + s"${a._1}: ${a._2.typeName}" }.mkString(", ") } } @@ -65,45 +70,43 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { private val builtinFunctions: Map[String, TVF] = Map( "range" -> Map( /* range(end) */ - ArgumentList(("end", classOf[Number])) -> ( + ArgumentList(("end", LongType)) -> ( (args: Seq[Any]) => - Range(0, args(0).asInstanceOf[Number].longValue, 1, defaultParallelism)), + Range(0, args(0).asInstanceOf[Long], 1, defaultParallelism)), /* range(start, end) */ - ArgumentList(("start", classOf[Number]), ("end", classOf[Number])) -> ( + ArgumentList(("start", LongType), ("end", LongType)) -> ( (args: Seq[Any]) => Range( - args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, 1, - defaultParallelism)), + args(0).asInstanceOf[Long], args(1).asInstanceOf[Long], 1, defaultParallelism)), /* range(start, end, step) */ - ArgumentList(("start", classOf[Number]), ("end", classOf[Number]), - ("step", classOf[Number])) -> ( + ArgumentList(("start", LongType), ("end", LongType), ("step", LongType)) -> ( (args: Seq[Any]) => Range( - args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, - args(2).asInstanceOf[Number].longValue, defaultParallelism)), + args(0).asInstanceOf[Long], args(1).asInstanceOf[Long], args(2).asInstanceOf[Long], + defaultParallelism)), /* range(start, end, step, numPartitions) */ - ArgumentList(("start", classOf[Number]), ("end", classOf[Number]), - ("step", classOf[Number]), ("numPartitions", classOf[Integer])) -> ( + ArgumentList(("start", LongType), ("end", LongType), ("step", LongType), + ("numPartitions", IntegerType)) -> ( (args: Seq[Any]) => Range( - args(0).asInstanceOf[Number].longValue, args(1).asInstanceOf[Number].longValue, - args(2).asInstanceOf[Number].longValue, args(3).asInstanceOf[Integer]))) + args(0).asInstanceOf[Long], args(1).asInstanceOf[Long], args(2).asInstanceOf[Long], + args(3).asInstanceOf[Integer]))) ) override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction => builtinFunctions.get(u.functionName) match { case Some(tvf) => - val evaluatedArgs = u.functionArgs.map(_.eval()) for ((argList, resolver) <- tvf) { - if (argList.assignableFrom(evaluatedArgs)) { - return resolver(evaluatedArgs) + val casted = argList.implicitCast(u.functionArgs) + if (casted.isDefined) { + return resolver(casted.get.map(_.eval())) } } - val argTypes = evaluatedArgs.map(_.getClass.getSimpleName).mkString(", ") + val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") u.failAnalysis( s"""error: table-valued function ${u.functionName} with alternatives: |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index 7893ec2dd05a..f1e254cf6d4f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -14,4 +14,4 @@ select * from range(0, 10, 2); select * from range(0, 10, 1, 200); -- range call error -select * from range(2, 'x'); +select * from range(1, 1, 1, 1, 1); diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index c77f77b0fbf5..c2e4d4f59e3d 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -65,14 +65,14 @@ struct -- !query 5 -select * from range(2, 'x') +select * from range(1, 1, 1, 1, 1) -- !query 5 schema struct<> -- !query 5 output org.apache.spark.sql.AnalysisException error: table-valued function range with alternatives: - (end: Number) - (start: Number, end: Number) - (start: Number, end: Number, step: Number) - (start: Number, end: Number, step: Number, numPartitions: Integer) -cannot be applied to: (Integer, UTF8String); line 1 pos 14 + (end: long) + (start: long, end: long) + (start: long, end: long, step: long) + (start: long, end: long, step: long, numPartitions: integer) +cannot be applied to: (integer, integer, integer, integer, integer); line 1 pos 14 From 8e03f51475797b8f4de249401dabbaaf9cf43f07 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 16 Aug 2016 13:36:54 -0700 Subject: [PATCH 18/21] wait for resolution --- .../catalyst/analysis/ResolveTableValuedFunctions.scala | 2 +- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 7 +------ .../resources/sql-tests/inputs/table-valued-functions.sql | 2 +- .../sql-tests/results/table-valued-functions.sql.out | 2 +- 4 files changed, 4 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index 87e69730ef44..d6a28708fd9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -97,7 +97,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { ) override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case u: UnresolvedTableValuedFunction => + case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => builtinFunctions.get(u.functionName) match { case Some(tvf) => for ((argList, resolver) <- tvf) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 380543647362..dfc885044cca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -662,12 +662,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ override def visitTableValuedFunction(ctx: TableValuedFunctionContext) : LogicalPlan = withOrigin(ctx) { - val expressions = ctx.expression.asScala.map { ec => - val e = expression(ec) - assert(e.foldable, "All arguments of a table-valued-function must be constants.", ec) - e - } - UnresolvedTableValuedFunction(ctx.identifier.getText, expressions) + UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression)) } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index f1e254cf6d4f..5cc85ca43692 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -2,7 +2,7 @@ select * from dummy(3); -- range call with end -select * from range(5); +select * from range(6 + cos(3)); -- range call with start and end select * from range(5, 10); diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index c2e4d4f59e3d..5e7bc82376fb 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -12,7 +12,7 @@ could not resolve `dummy` to a table-valued function; line 1 pos 14 -- !query 1 -select * from range(5) +select * from range(6 + cos(3)) -- !query 1 schema struct -- !query 1 output From 1fa57c4d63c8ece438fde7f8d23d0b0698d22cd9 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 16 Aug 2016 16:00:21 -0700 Subject: [PATCH 19/21] partial functions --- .../ResolveTableValuedFunctions.scala | 41 ++++++++++--------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index d6a28708fd9c..fbc7f449de3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -65,35 +65,38 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { private type TVF = Map[ArgumentList, Seq[Any] => LogicalPlan] /** - * Internal registry of table-valued functions. TODO(ekl) we should have a proper registry + * TVF builder. + */ + private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan]) + : (ArgumentList, Seq[Any] => LogicalPlan) = (ArgumentList(args: _*), pf) + + /** + * Internal registry of table-valued functions. */ private val builtinFunctions: Map[String, TVF] = Map( "range" -> Map( /* range(end) */ - ArgumentList(("end", LongType)) -> ( - (args: Seq[Any]) => - Range(0, args(0).asInstanceOf[Long], 1, defaultParallelism)), + tvf("end" -> LongType) { case Seq(end: Long) => + Range(0, end, 1, defaultParallelism) + }, /* range(start, end) */ - ArgumentList(("start", LongType), ("end", LongType)) -> ( - (args: Seq[Any]) => - Range( - args(0).asInstanceOf[Long], args(1).asInstanceOf[Long], 1, defaultParallelism)), + tvf("start" -> LongType, "end" -> LongType) { case Seq(start: Long, end: Long) => + Range(start, end, 1, defaultParallelism) + }, /* range(start, end, step) */ - ArgumentList(("start", LongType), ("end", LongType), ("step", LongType)) -> ( - (args: Seq[Any]) => - Range( - args(0).asInstanceOf[Long], args(1).asInstanceOf[Long], args(2).asInstanceOf[Long], - defaultParallelism)), + tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) { + case Seq(start: Long, end: Long, step: Long) => + Range(start, end, step, defaultParallelism) + }, /* range(start, end, step, numPartitions) */ - ArgumentList(("start", LongType), ("end", LongType), ("step", LongType), - ("numPartitions", IntegerType)) -> ( - (args: Seq[Any]) => - Range( - args(0).asInstanceOf[Long], args(1).asInstanceOf[Long], args(2).asInstanceOf[Long], - args(3).asInstanceOf[Integer]))) + tvf("start" -> LongType, "end" -> LongType, "step" -> LongType, + "numPartitions" -> IntegerType) { + case Seq(start: Long, end: Long, step: Long, numPartitions: Int) => + Range(start, end, step, numPartitions) + }) ) override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { From f8831cae68bb553f3b86933ed153979cf8d0c02c Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 17 Aug 2016 19:50:40 -0700 Subject: [PATCH 20/21] Wed Aug 17 19:50:40 PDT 2016 --- .../ResolveTableValuedFunctions.scala | 39 ++++++++++++------- .../inputs/table-valued-functions.sql | 3 ++ .../results/table-valued-functions.sql.out | 11 +++++- 3 files changed, 38 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index fbc7f449de3e..7fdf7fa0c06a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -68,7 +68,14 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { * TVF builder. */ private def tvf(args: (String, DataType)*)(pf: PartialFunction[Seq[Any], LogicalPlan]) - : (ArgumentList, Seq[Any] => LogicalPlan) = (ArgumentList(args: _*), pf) + : (ArgumentList, Seq[Any] => LogicalPlan) = { + (ArgumentList(args: _*), + pf orElse { + case args => + throw new IllegalArgumentException( + "Invalid arguments for resolved function: " + args.mkString(", ")) + }) + } /** * Internal registry of table-valued functions. @@ -87,15 +94,15 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { /* range(start, end, step) */ tvf("start" -> LongType, "end" -> LongType, "step" -> LongType) { - case Seq(start: Long, end: Long, step: Long) => - Range(start, end, step, defaultParallelism) + case Seq(start: Long, end: Long, step: Long) => + Range(start, end, step, defaultParallelism) }, /* range(start, end, step, numPartitions) */ tvf("start" -> LongType, "end" -> LongType, "step" -> LongType, "numPartitions" -> IntegerType) { - case Seq(start: Long, end: Long, step: Long, numPartitions: Int) => - Range(start, end, step, numPartitions) + case Seq(start: Long, end: Long, step: Long, numPartitions: Int) => + Range(start, end, step, numPartitions) }) ) @@ -103,17 +110,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => builtinFunctions.get(u.functionName) match { case Some(tvf) => - for ((argList, resolver) <- tvf) { - val casted = argList.implicitCast(u.functionArgs) - if (casted.isDefined) { - return resolver(casted.get.map(_.eval())) + val resolved = tvf.flatMap { case (argList, resolver) => + argList.implicitCast(u.functionArgs) match { + case Some(casted) => + Some(resolver(casted.map(_.eval()))) + case _ => + None } } - val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") - u.failAnalysis( - s"""error: table-valued function ${u.functionName} with alternatives: - |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} - |cannot be applied to: (${argTypes})""".stripMargin) + resolved.headOption.getOrElse { + val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ") + u.failAnalysis( + s"""error: table-valued function ${u.functionName} with alternatives: + |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")} + |cannot be applied to: (${argTypes})""".stripMargin) + } case _ => u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function") } diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index 5cc85ca43692..c2a405f87324 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -15,3 +15,6 @@ select * from range(0, 10, 1, 200); -- range call error select * from range(1, 1, 1, 1, 1); + +-- range call with null +select * from range(null); diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index 5e7bc82376fb..d769bcef0aca 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 6 +-- Number of queries: 7 -- !query 0 @@ -76,3 +76,12 @@ error: table-valued function range with alternatives: (start: long, end: long, step: long) (start: long, end: long, step: long, numPartitions: integer) cannot be applied to: (integer, integer, integer, integer, integer); line 1 pos 14 + + +-- !query 6 +select * from range(1, null) +-- !query 6 schema +struct<> +-- !query 6 output +java.lang.IllegalArgumentException +Invalid arguments for resolved function: 1, null From 7ebd563fca73ae4a4e05970709f334a4d09b5ff1 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 17 Aug 2016 19:50:55 -0700 Subject: [PATCH 21/21] Wed Aug 17 19:50:55 PDT 2016 --- .../test/resources/sql-tests/inputs/table-valued-functions.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql index c2a405f87324..2e6dcd538b7a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql @@ -17,4 +17,4 @@ select * from range(0, 10, 1, 200); select * from range(1, 1, 1, 1, 1); -- range call with null -select * from range(null); +select * from range(1, null);