From 2895b3b07426855220336908ad50c253a9848efe Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Sat, 2 Nov 2019 16:57:50 +0800 Subject: [PATCH 1/2] [SPARK-29722][SQL] Non reserved keywords should be able to be used in higher-order functions --- .../spark/sql/catalyst/parser/SqlBase.g4 | 4 +- .../sql/catalyst/parser/AstBuilder.scala | 2 +- .../inputs/higher-order-functions.sql | 5 + .../results/higher-order-functions.sql.out | 156 ++++++++++-------- 4 files changed, 96 insertions(+), 71 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 96d1e42ffafe..f0a040bceb28 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -718,8 +718,8 @@ primaryExpression | '(' query ')' #subqueryExpression | qualifiedName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' (OVER windowSpec)? #functionCall - | IDENTIFIER '->' expression #lambda - | '(' IDENTIFIER (',' IDENTIFIER)+ ')' '->' expression #lambda + | identifier '->' expression #lambda + | '(' identifier (',' identifier)+ ')' '->' expression #lambda | value=primaryExpression '[' index=valueExpression ']' #subscript | identifier #columnReference | base=primaryExpression '.' fieldName=identifier #dereference diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 838fc4d84a5d..0bc87095ce53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1559,7 +1559,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * Create an [[LambdaFunction]]. 
*/ override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) { - val arguments = ctx.IDENTIFIER().asScala.map { name => + val arguments = ctx.identifier().asScala.map { name => UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts) } val function = expression(ctx.expression).transformUp { diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index 02ad5e353868..d8390cc49858 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -12,9 +12,14 @@ select transform(zs, z -> z) as v from nested; -- Transform an array select transform(ys, y -> y * y) as v from nested; +-- use non reversed keywords +select transform(ys, left -> left * left) as v from nested; + -- Transform an array with index select transform(ys, (y, i) -> y + i) as v from nested; +-- use non reversed keywords +select transform(ys, (cost, i) -> cost + i) as v from nested; -- Transform an array with reference select transform(zs, z -> concat(ys, z)) as v from nested; diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index 1b7c6f4f7625..d418b887924f 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 27 +-- Number of queries: 29 -- !query 0 @@ -44,213 +44,233 @@ struct> -- !query 4 -select transform(ys, (y, i) -> y + i) as v from nested +select transform(ys, left -> left * left) as v from nested -- !query 4 schema struct> -- !query 4 output +[1024,9409] +[144] +[5929,5776] + + +-- !query 5 +select transform(ys, (y, i) -> y + i) as v 
from nested +-- !query 5 schema +struct> +-- !query 5 output [12] [32,98] [77,-75] --- !query 5 +-- !query 6 +select transform(ys, (cost, i) -> cost + i) as v from nested +-- !query 6 schema +struct> +-- !query 6 output +[12] +[32,98] +[77,-75] + + +-- !query 7 select transform(zs, z -> concat(ys, z)) as v from nested --- !query 5 schema +-- !query 7 schema struct>> --- !query 5 output +-- !query 7 output [[12,17]] [[32,97,12,99],[32,97,123,42],[32,97,1]] [[77,-76,6,96,65],[77,-76,-1,-2]] --- !query 6 +-- !query 8 select transform(ys, 0) as v from nested --- !query 6 schema +-- !query 8 schema struct> --- !query 6 output +-- !query 8 output [0,0] [0,0] [0] --- !query 7 +-- !query 9 select transform(cast(null as array), x -> x + 1) as v --- !query 7 schema +-- !query 9 schema struct> --- !query 7 output +-- !query 9 output NULL --- !query 8 +-- !query 10 select filter(ys, y -> y > 30) as v from nested --- !query 8 schema +-- !query 10 schema struct> --- !query 8 output +-- !query 10 output [32,97] [77] [] --- !query 9 +-- !query 11 select filter(cast(null as array), y -> true) as v --- !query 9 schema +-- !query 11 schema struct> --- !query 9 output +-- !query 11 output NULL --- !query 10 +-- !query 12 select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested --- !query 10 schema +-- !query 12 schema struct>> --- !query 10 output +-- !query 12 output [[96,65],[]] [[99],[123],[]] [[]] --- !query 11 +-- !query 13 select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested --- !query 11 schema +-- !query 13 schema struct --- !query 11 output +-- !query 13 output 131 15 5 --- !query 12 +-- !query 14 select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested --- !query 12 schema +-- !query 14 schema struct --- !query 12 output +-- !query 14 output 0.5 12.0 64.5 --- !query 13 +-- !query 15 select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested --- !query 13 
schema +-- !query 15 schema struct> --- !query 13 output +-- !query 15 output [1010880,8] [17] [4752,20664,1] --- !query 14 +-- !query 16 select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) as v --- !query 14 schema +-- !query 16 schema struct --- !query 14 output +-- !query 16 output NULL --- !query 15 +-- !query 17 select exists(ys, y -> y > 30) as v from nested --- !query 15 schema +-- !query 17 schema struct --- !query 15 output +-- !query 17 output false true true --- !query 16 +-- !query 18 select exists(cast(null as array), y -> y > 30) as v --- !query 16 schema +-- !query 18 schema struct --- !query 16 output +-- !query 18 output NULL --- !query 17 +-- !query 19 select zip_with(ys, zs, (a, b) -> a + size(b)) as v from nested --- !query 17 schema +-- !query 19 schema struct> --- !query 17 output +-- !query 19 output [13] [34,99,null] [80,-74] --- !query 18 +-- !query 20 select zip_with(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)) as v --- !query 18 schema +-- !query 20 schema struct> --- !query 18 output +-- !query 20 output ["ad","be","cf"] --- !query 19 +-- !query 21 select zip_with(array('a'), array('d', null, 'f'), (x, y) -> coalesce(x, y)) as v --- !query 19 schema +-- !query 21 schema struct> --- !query 19 output +-- !query 21 output ["a",null,"f"] --- !query 20 +-- !query 22 create or replace temporary view nested as values (1, map(1, 1, 2, 2, 3, 3)), (2, map(4, 4, 5, 5, 6, 6)) as t(x, ys) --- !query 20 schema +-- !query 22 schema struct<> --- !query 20 output +-- !query 22 output --- !query 21 +-- !query 23 select transform_keys(ys, (k, v) -> k) as v from nested --- !query 21 schema +-- !query 23 schema struct> --- !query 21 output +-- !query 23 output {1:1,2:2,3:3} {4:4,5:5,6:6} --- !query 22 +-- !query 24 select transform_keys(ys, (k, v) -> k + 1) as v from nested --- !query 22 schema +-- !query 24 schema struct> --- !query 22 output +-- !query 24 output {2:1,3:2,4:3} {5:4,6:5,7:6} --- !query 23 +-- 
!query 25 select transform_keys(ys, (k, v) -> k + v) as v from nested --- !query 23 schema +-- !query 25 schema struct> --- !query 23 output +-- !query 25 output {10:5,12:6,8:4} {2:1,4:2,6:3} --- !query 24 +-- !query 26 select transform_values(ys, (k, v) -> v) as v from nested --- !query 24 schema +-- !query 26 schema struct> --- !query 24 output +-- !query 26 output {1:1,2:2,3:3} {4:4,5:5,6:6} --- !query 25 +-- !query 27 select transform_values(ys, (k, v) -> v + 1) as v from nested --- !query 25 schema +-- !query 27 schema struct> --- !query 25 output +-- !query 27 output {1:2,2:3,3:4} {4:5,5:6,6:7} --- !query 26 +-- !query 28 select transform_values(ys, (k, v) -> k + v) as v from nested --- !query 26 schema +-- !query 28 schema struct> --- !query 26 output +-- !query 28 output {1:2,2:4,3:6} {4:8,5:10,6:12} From 5db6d6d528d9e7fd88a1d3301cfbba877a38cfb7 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Sun, 3 Nov 2019 17:05:57 +0800 Subject: [PATCH 2/2] mv and add some ut --- .../inputs/higher-order-functions.sql | 14 +- .../results/higher-order-functions.sql.out | 216 +++++++++++------- 2 files changed, 137 insertions(+), 93 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index d8390cc49858..8d5d9fae7a73 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -12,14 +12,9 @@ select transform(zs, z -> z) as v from nested; -- Transform an array select transform(ys, y -> y * y) as v from nested; --- use non reversed keywords -select transform(ys, left -> left * left) as v from nested; - -- Transform an array with index select transform(ys, (y, i) -> y + i) as v from nested; --- use non reversed keywords -select transform(ys, (cost, i) -> cost + i) as v from nested; -- Transform an array with reference select transform(zs, z -> concat(ys, z)) 
as v from nested; @@ -88,3 +83,12 @@ select transform_values(ys, (k, v) -> v + 1) as v from nested; -- Transform values in a map using values select transform_values(ys, (k, v) -> k + v) as v from nested; + +-- use non reserved keywords: all is non reserved only if !ansi +select transform(ys, all -> all * all) as v from values (array(32, 97)) as t(ys); +select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(ys); + +set spark.sql.ansi.enabled=true; +select transform(ys, all -> all * all) as v from values (array(32, 97)) as t(ys); +select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(ys); +set spark.sql.ansi.enabled=false; diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index d418b887924f..0b78076588c1 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 29 +-- Number of queries: 33 -- !query 0 @@ -44,233 +44,273 @@ struct> -- !query 4 -select transform(ys, left -> left * left) as v from nested +select transform(ys, (y, i) -> y + i) as v from nested -- !query 4 schema struct> -- !query 4 output -[1024,9409] -[144] -[5929,5776] - - --- !query 5 -select transform(ys, (y, i) -> y + i) as v from nested --- !query 5 schema -struct> --- !query 5 output [12] [32,98] [77,-75] --- !query 6 -select transform(ys, (cost, i) -> cost + i) as v from nested --- !query 6 schema -struct> --- !query 6 output -[12] -[32,98] -[77,-75] - - --- !query 7 +-- !query 5 select transform(zs, z -> concat(ys, z)) as v from nested --- !query 7 schema +-- !query 5 schema struct>> --- !query 7 output +-- !query 5 output [[12,17]] [[32,97,12,99],[32,97,123,42],[32,97,1]] [[77,-76,6,96,65],[77,-76,-1,-2]] --- !query 8 +-- !query 
6 select transform(ys, 0) as v from nested --- !query 8 schema +-- !query 6 schema struct> --- !query 8 output +-- !query 6 output [0,0] [0,0] [0] --- !query 9 +-- !query 7 select transform(cast(null as array), x -> x + 1) as v --- !query 9 schema +-- !query 7 schema struct> --- !query 9 output +-- !query 7 output NULL --- !query 10 +-- !query 8 select filter(ys, y -> y > 30) as v from nested --- !query 10 schema +-- !query 8 schema struct> --- !query 10 output +-- !query 8 output [32,97] [77] [] --- !query 11 +-- !query 9 select filter(cast(null as array), y -> true) as v --- !query 11 schema +-- !query 9 schema struct> --- !query 11 output +-- !query 9 output NULL --- !query 12 +-- !query 10 select transform(zs, z -> filter(z, zz -> zz > 50)) as v from nested --- !query 12 schema +-- !query 10 schema struct>> --- !query 12 output +-- !query 10 output [[96,65],[]] [[99],[123],[]] [[]] --- !query 13 +-- !query 11 select aggregate(ys, 0, (y, a) -> y + a + x) as v from nested --- !query 13 schema +-- !query 11 schema struct --- !query 13 output +-- !query 11 output 131 15 5 --- !query 14 +-- !query 12 select aggregate(ys, (0 as sum, 0 as n), (acc, x) -> (acc.sum + x, acc.n + 1), acc -> acc.sum / acc.n) as v from nested --- !query 14 schema +-- !query 12 schema struct --- !query 14 output +-- !query 12 output 0.5 12.0 64.5 --- !query 15 +-- !query 13 select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as v from nested --- !query 15 schema +-- !query 13 schema struct> --- !query 15 output +-- !query 13 output [1010880,8] [17] [4752,20664,1] --- !query 16 +-- !query 14 select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) as v --- !query 16 schema +-- !query 14 schema struct --- !query 16 output +-- !query 14 output NULL --- !query 17 +-- !query 15 select exists(ys, y -> y > 30) as v from nested --- !query 17 schema +-- !query 15 schema struct --- !query 17 output +-- !query 15 output false true true --- !query 18 +-- !query 
16 select exists(cast(null as array), y -> y > 30) as v --- !query 18 schema +-- !query 16 schema struct --- !query 18 output +-- !query 16 output NULL --- !query 19 +-- !query 17 select zip_with(ys, zs, (a, b) -> a + size(b)) as v from nested --- !query 19 schema +-- !query 17 schema struct> --- !query 19 output +-- !query 17 output [13] [34,99,null] [80,-74] --- !query 20 +-- !query 18 select zip_with(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)) as v --- !query 20 schema +-- !query 18 schema struct> --- !query 20 output +-- !query 18 output ["ad","be","cf"] --- !query 21 +-- !query 19 select zip_with(array('a'), array('d', null, 'f'), (x, y) -> coalesce(x, y)) as v --- !query 21 schema +-- !query 19 schema struct> --- !query 21 output +-- !query 19 output ["a",null,"f"] --- !query 22 +-- !query 20 create or replace temporary view nested as values (1, map(1, 1, 2, 2, 3, 3)), (2, map(4, 4, 5, 5, 6, 6)) as t(x, ys) --- !query 22 schema +-- !query 20 schema struct<> --- !query 22 output +-- !query 20 output --- !query 23 +-- !query 21 select transform_keys(ys, (k, v) -> k) as v from nested --- !query 23 schema +-- !query 21 schema struct> --- !query 23 output +-- !query 21 output {1:1,2:2,3:3} {4:4,5:5,6:6} --- !query 24 +-- !query 22 select transform_keys(ys, (k, v) -> k + 1) as v from nested --- !query 24 schema +-- !query 22 schema struct> --- !query 24 output +-- !query 22 output {2:1,3:2,4:3} {5:4,6:5,7:6} --- !query 25 +-- !query 23 select transform_keys(ys, (k, v) -> k + v) as v from nested --- !query 25 schema +-- !query 23 schema struct> --- !query 25 output +-- !query 23 output {10:5,12:6,8:4} {2:1,4:2,6:3} --- !query 26 +-- !query 24 select transform_values(ys, (k, v) -> v) as v from nested --- !query 26 schema +-- !query 24 schema struct> --- !query 26 output +-- !query 24 output {1:1,2:2,3:3} {4:4,5:5,6:6} --- !query 27 +-- !query 25 select transform_values(ys, (k, v) -> v + 1) as v from nested --- !query 27 schema +-- !query 25 
schema struct> --- !query 27 output +-- !query 25 output {1:2,2:3,3:4} {4:5,5:6,6:7} --- !query 28 +-- !query 26 select transform_values(ys, (k, v) -> k + v) as v from nested --- !query 28 schema +-- !query 26 schema struct> --- !query 28 output +-- !query 26 output {1:2,2:4,3:6} {4:8,5:10,6:12} + + +-- !query 27 +select transform(ys, all -> all * all) as v from values (array(32, 97)) as t(ys) +-- !query 27 schema +struct> +-- !query 27 output +[1024,9409] + + +-- !query 28 +select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(ys) +-- !query 28 schema +struct> +-- !query 28 output +[32,98] + + +-- !query 29 +set spark.sql.ansi.enabled=true +-- !query 29 schema +struct +-- !query 29 output +spark.sql.ansi.enabled true + + +-- !query 30 +select transform(ys, all -> all * all) as v from values (array(32, 97)) as t(ys) +-- !query 30 schema +struct<> +-- !query 30 output +org.apache.spark.sql.catalyst.parser.ParseException + +no viable alternative at input 'all'(line 1, pos 21) + +== SQL == +select transform(ys, all -> all * all) as v from values (array(32, 97)) as t(ys) +---------------------^^^ + + +-- !query 31 +select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(ys) +-- !query 31 schema +struct<> +-- !query 31 output +org.apache.spark.sql.catalyst.parser.ParseException + +no viable alternative at input 'all'(line 1, pos 22) + +== SQL == +select transform(ys, (all, i) -> all + i) as v from values (array(32, 97)) as t(ys) +----------------------^^^ + + +-- !query 32 +set spark.sql.ansi.enabled=false +-- !query 32 schema +struct +-- !query 32 output +spark.sql.ansi.enabled false