diff --git a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/keygen/TimestampBasedAvroKeyGenerator.java b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/keygen/TimestampBasedAvroKeyGenerator.java index bce7e24c57a5f..0c788188a9c02 100644 --- a/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/keygen/TimestampBasedAvroKeyGenerator.java +++ b/hudi-client/hudi-client-common/src/main/java/org/apache/hudi/keygen/TimestampBasedAvroKeyGenerator.java @@ -65,6 +65,29 @@ public enum TimestampType implements Serializable { protected final boolean encodePartitionPath; + /** + * Supported configs. + */ + public static class Config { + + // One value from TimestampType above + public static final String TIMESTAMP_TYPE_FIELD_PROP = "hoodie.deltastreamer.keygen.timebased.timestamp.type"; + public static final String INPUT_TIME_UNIT = + "hoodie.deltastreamer.keygen.timebased.timestamp.scalar.time.unit"; + //This prop can now accept list of input date formats. + public static final String TIMESTAMP_INPUT_DATE_FORMAT_PROP = + "hoodie.deltastreamer.keygen.timebased.input.dateformat"; + public static final String TIMESTAMP_INPUT_DATE_FORMAT_LIST_DELIMITER_REGEX_PROP = "hoodie.deltastreamer.keygen.timebased.input.dateformat.list.delimiter.regex"; + public static final String TIMESTAMP_INPUT_TIMEZONE_FORMAT_PROP = "hoodie.deltastreamer.keygen.timebased.input.timezone"; + public static final String TIMESTAMP_OUTPUT_DATE_FORMAT_PROP = + "hoodie.deltastreamer.keygen.timebased.output.dateformat"; + //still keeping this prop for backward compatibility so that functionality for existing users does not break. + public static final String TIMESTAMP_TIMEZONE_FORMAT_PROP = + "hoodie.deltastreamer.keygen.timebased.timezone"; + public static final String TIMESTAMP_OUTPUT_TIMEZONE_FORMAT_PROP = "hoodie.deltastreamer.keygen.timebased.output.timezone"; + static final String DATE_TIME_PARSER_PROP = "hoodie.deltastreamer.keygen.datetime.parser.class"; + } + public TimestampBasedAvroKeyGenerator(TypedProperties config) throws IOException { this(config, config.getString(KeyGeneratorOptions.RECORDKEY_FIELD_NAME.key()), config.getString(KeyGeneratorOptions.PARTITIONPATH_FIELD_NAME.key())); diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala index 4cc1bca1d614d..8e4c6376dd851 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala +++ b/hudi-client/hudi-spark-client/src/main/scala/org/apache/hudi/SparkAdapterSupport.scala @@ -27,7 +27,11 @@ import org.apache.spark.sql.hudi.SparkAdapter trait SparkAdapterSupport { lazy val sparkAdapter: SparkAdapter = { - val adapterClass = if (HoodieSparkUtils.isSpark3) { + val adapterClass = if (HoodieSparkUtils.isSpark3_1) { + "org.apache.spark.sql.adapter.Spark3_1Adapter" + } else if (HoodieSparkUtils.isSpark3_2) { + "org.apache.spark.sql.adapter.Spark3_2Adapter" + } else if (HoodieSparkUtils.isSpark3) { "org.apache.spark.sql.adapter.Spark3Adapter" } else { "org.apache.spark.sql.adapter.Spark2Adapter" diff --git a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala index 20b4d3cc1be73..9509e3fef4424 100644 --- a/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala +++ 
b/hudi-client/hudi-spark-client/src/main/scala/org/apache/spark/sql/hudi/SparkAdapter.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.hudi -import org.apache.hudi.HoodieSparkUtils.sparkAdapter import org.apache.hudi.client.utils.SparkRowSerDe import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable diff --git a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala index 12bbd64851001..17f3c1bc4f217 100644 --- a/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala +++ b/hudi-spark-datasource/hudi-spark/src/main/scala/org/apache/spark/sql/hudi/analysis/HoodieAnalysis.scala @@ -17,17 +17,19 @@ package org.apache.spark.sql.hudi.analysis +import org.apache.hudi.DataSourceReadOptions import org.apache.hudi.DataSourceWriteOptions.MOR_TABLE_TYPE_OPT_VAL import org.apache.hudi.common.model.HoodieRecord import org.apache.hudi.common.util.ReflectionUtils import org.apache.hudi.{HoodieSparkUtils, SparkAdapterSupport} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.catalog.{CatalogUtils, HoodieCatalogTable} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} import org.apache.spark.sql.hudi.HoodieSqlCommonUtils.{getTableIdentifier, removeMetaFields} import org.apache.spark.sql.hudi.HoodieSqlUtils._ import org.apache.spark.sql.hudi.command._ @@ -110,6 +112,7 @@ case class HoodieAnalysis(sparkSession: SparkSession) extends Rule[LogicalPlan] case _ => l } + // Convert to CreateHoodieTableAsSelectCommand case CreateTable(table, mode, Some(query)) if query.resolved && sparkAdapter.isHoodieTable(table) => @@ -133,7 +136,7 @@ case class HoodieAnalysis(sparkSession: SparkSession) extends Rule[LogicalPlan] // Convert to CompactionShowHoodiePathCommand case CompactionShowOnPath(path, limit) => CompactionShowHoodiePathCommand(path, limit) - case _=> plan + case _ => plan } } } @@ -350,6 +353,35 @@ case class HoodieResolveReferences(sparkSession: SparkSession) extends Rule[Logi l } + case TimeTravelRelation(plan: UnresolvedRelation, timestamp, version) => + // TODO: How to use version to perform time travel? 
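// Illustrative walk-through (the table name `h0` below is hypothetical, not part of this patch):
//   SELECT * FROM h0 TIMESTAMP AS OF '20220101000000'
// arrives here as TimeTravelRelation(UnresolvedRelation(`h0`), Some(Literal("20220101000000")), None).
// The branch below rewrites it into a LogicalRelation whose DataSource options carry
// DataSourceReadOptions.TIME_TRAVEL_AS_OF_INSTANT, so the relation is resolved against the snapshot
// as of that instant.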
+ if (timestamp.isEmpty && version.nonEmpty) { + throw new AnalysisException( + "version expression is not supported for time travel") + } + + val tableIdentifier = sparkAdapter.toTableIdentifier(plan) + if (sparkAdapter.isHoodieTable(tableIdentifier, sparkSession)) { + val hoodieCatalogTable = HoodieCatalogTable(sparkSession, tableIdentifier) + val table = hoodieCatalogTable.table + val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) + val instantOption = Map( + DataSourceReadOptions.TIME_TRAVEL_AS_OF_INSTANT.key -> timestamp.get.toString()) + val dataSource = + DataSource( + sparkSession, + userSpecifiedSchema = if (table.schema.isEmpty) None else Some(table.schema), + partitionColumns = table.partitionColumnNames, + bucketSpec = table.bucketSpec, + className = table.provider.get, + options = table.storage.properties ++ pathOption ++ instantOption, + catalogTable = Some(table)) + + LogicalRelation(dataSource.resolveRelation(checkFilesExist = false), table) + } else { + plan + } + + case p => p } diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestTimeTravelParser.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestTimeTravelParser.scala new file mode 100644 index 0000000000000..4ad6674a300eb --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestTimeTravelParser.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.hudi + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.logical.{Project, TimeTravelRelation} + +class TestTimeTravelParser extends TestHoodieSqlBase { + private val parser = spark.sessionState.sqlParser + + test("time travel of timestamp") { + val timeTravelPlan1 = parser.parsePlan("SELECT * FROM A.B " + + "TIMESTAMP AS OF '2019-01-29 00:37:58'") + + assertResult(Project(Seq(UnresolvedStar(None)), + TimeTravelRelation( + UnresolvedRelation(new TableIdentifier("B", Option.apply("A"))), + Some(Literal("2019-01-29 00:37:58")), + None))) { + timeTravelPlan1 + } + + val timeTravelPlan2 = parser.parsePlan("SELECT * FROM A.B " + + "TIMESTAMP AS OF 1643119574") + + assertResult(Project(Seq(UnresolvedStar(None)), + TimeTravelRelation( + UnresolvedRelation(new TableIdentifier("B", Option.apply("A"))), + Some(Literal(1643119574)), + None))) { + timeTravelPlan2 + } + } + + test("time travel of version") { + val timeTravelPlan1 = parser.parsePlan("SELECT * FROM A.B " + + "VERSION AS OF 'Snapshot123456789'") + + assertResult(Project(Seq(UnresolvedStar(None)), + TimeTravelRelation( + UnresolvedRelation(new TableIdentifier("B", Option.apply("A"))), + None, + Some("Snapshot123456789")))) { + timeTravelPlan1 + } + + val timeTravelPlan2 = parser.parsePlan("SELECT * FROM A.B " + + "VERSION AS OF 'Snapshot01'") + + assertResult(Project(Seq(UnresolvedStar(None)), + TimeTravelRelation( + UnresolvedRelation(new TableIdentifier("B", Option.apply("A"))), + None, + Some("Snapshot01")))) { + timeTravelPlan2 + } + } +} diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestTimeTravelTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestTimeTravelTable.scala new file mode 100644 index 0000000000000..eaebd2d96f084 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestTimeTravelTable.scala @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hudi + +import org.apache.hudi.common.table.HoodieTableMetaClient + +class TestTimeTravelTable extends TestHoodieSqlBase { + test("Test Insert and Update with time travel") { + withTempDir { tmp => + val tableName1 = generateTableName + spark.sql( + s""" + |create table $tableName1 ( + | id int, + | name string, + | price double, + | ts long + |) using hudi + | tblproperties ( + | type = 'cow', + | primaryKey = 'id', + | preCombineField = 'ts' + | ) + | location '${tmp.getCanonicalPath}/$tableName1' + """.stripMargin) + + spark.sql(s"insert into $tableName1 values(1, 'a1', 10, 1000)") + + val metaClient1 = HoodieTableMetaClient.builder() + .setBasePath(s"${tmp.getCanonicalPath}/$tableName1") + .setConf(spark.sessionState.newHadoopConf()) + .build() + + val instant1 = metaClient1.getActiveTimeline.getAllCommitsTimeline + .lastInstant().get().getTimestamp + + spark.sql(s"insert into $tableName1 values(1, 'a2', 20, 2000)") + + checkAnswer(s"select id, name, price, ts from $tableName1")( + Seq(1, "a2", 20.0, 2000) + ) + + // time travel from instant1 + checkAnswer( + s"select id, name, price, ts from $tableName1 TIMESTAMP AS OF '$instant1'")( + Seq(1, "a1", 10.0, 1000) + ) + } + } + + test("Test Two Table's Union Join with time travel") { + withTempDir { tmp => + Seq("cow", "mor").foreach { tableType => + val tableName = generateTableName + + val basePath = tmp.getCanonicalPath + val tableName1 = tableName + "_1" + val tableName2 = tableName + "_2" + val path1 = s"$basePath/$tableName1" + val path2 = s"$basePath/$tableName2" + + spark.sql( + s""" + |create table $tableName1 ( + | id int, + | name string, + | price double, + | ts long + |) using hudi + | tblproperties ( + | type = '$tableType', + | primaryKey = 'id', + | preCombineField = 'ts' + | ) + | location '$path1' + """.stripMargin) + + spark.sql( + s""" + |create table $tableName2 ( + | id int, + | name string, + | price double, + | ts long + |) using hudi + | tblproperties ( + | type = '$tableType', + | primaryKey = 'id', + | preCombineField = 'ts' + | ) + | location '$path2' + """.stripMargin) + + spark.sql(s"insert into $tableName1 values(1, 'a1', 10, 1000)") + spark.sql(s"insert into $tableName1 values(2, 'a2', 20, 1000)") + + checkAnswer(s"select id, name, price, ts from $tableName1")( + Seq(1, "a1", 10.0, 1000), + Seq(2, "a2", 20.0, 1000) + ) + + checkAnswer(s"select id, name, price, ts from $tableName1")( + Seq(1, "a1", 10.0, 1000), + Seq(2, "a2", 20.0, 1000) + ) + + spark.sql(s"insert into $tableName2 values(3, 'a3', 10, 1000)") + spark.sql(s"insert into $tableName2 values(4, 'a4', 20, 1000)") + + checkAnswer(s"select id, name, price, ts from $tableName2")( + Seq(3, "a3", 10.0, 1000), + Seq(4, "a4", 20.0, 1000) + ) + + val metaClient1 = HoodieTableMetaClient.builder() + .setBasePath(path1) + .setConf(spark.sessionState.newHadoopConf()) + .build() + + val metaClient2 = HoodieTableMetaClient.builder() + .setBasePath(path2) + .setConf(spark.sessionState.newHadoopConf()) + .build() + + val instant1 = metaClient1.getActiveTimeline.getAllCommitsTimeline + .lastInstant().get().getTimestamp + + val instant2 = metaClient2.getActiveTimeline.getAllCommitsTimeline + .lastInstant().get().getTimestamp + + val sql = + s""" + |select id, name, price, ts from $tableName1 TIMESTAMP AS OF '$instant1' where id=1 + |union + |select id, name, price, ts from $tableName2 TIMESTAMP AS OF '$instant2' where id>1 + |""".stripMargin + + checkAnswer(sql)( + Seq(1, "a1", 10.0, 1000), + Seq(3, "a3", 10.0, 1000), + Seq(4, 
"a4", 20.0, 1000) + ) + } + } + } + + test("Test Insert Into with time travel") { + withTempDir { tmp => + // Create Non-Partitioned table + val tableName1 = generateTableName + spark.sql( + s""" + |create table $tableName1 ( + | id int, + | name string, + | price double, + | ts long + |) using hudi + | tblproperties ( + | type = 'cow', + | primaryKey = 'id', + | preCombineField = 'ts' + | ) + | location '${tmp.getCanonicalPath}/$tableName1' + """.stripMargin) + + spark.sql(s"insert into $tableName1 values(1, 'a1', 10, 1000)") + + val metaClient1 = HoodieTableMetaClient.builder() + .setBasePath(s"${tmp.getCanonicalPath}/$tableName1") + .setConf(spark.sessionState.newHadoopConf()) + .build() + + val instant1 = metaClient1.getActiveTimeline.getAllCommitsTimeline + .lastInstant().get().getTimestamp + + + val tableName2 = generateTableName + // Create a partitioned table + spark.sql( + s""" + |create table $tableName2 ( + | id int, + | name string, + | price double, + | ts long, + | dt string + |) using hudi + | tblproperties (primaryKey = 'id') + | partitioned by (dt) + | location '${tmp.getCanonicalPath}/$tableName2' + """.stripMargin) + + // Insert into dynamic partition + spark.sql( + s""" + | insert into $tableName2 + + | select id, name, price, ts, '2022- + + | from $tableName1 TIMESTAMP AS OF '$instant1' + """. + stripMargin) + checkAnswer(s"select id, name, price, ts, dt from $tableName2")( + Seq(1, "a1", 10.0, 1000, "2022-02-14") + ) + + // Ins static partition + spark.sql( + s""" + | insert into $tableName2 partition(dt = '2022-02-15') + | select 2 as id, 'a2' as name, price, ts from $tableName1 + TIMESTAMP AS OF '$instant1' + """.stripMargin) + checkAnswer( + s"select id, name, price, ts, dt from $tableName2")( + Seq(1, "a1", 10.0, 1000, "2022-02-14"), + Seq(2, "a2", 10.0, 1000, "2022-02-15") + ) + } + } +} diff --git a/hudi-spark-datasource/hudi-spark2/src/main/antlr4/imports/SqlBase.g4 b/hudi-spark-datasource/hudi-spark2/src/main/antlr4/imports/SqlBase.g4 index 2add2b030f538..a63457d842e5e 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/antlr4/imports/SqlBase.g4 +++ b/hudi-spark-datasource/hudi-spark2/src/main/antlr4/imports/SqlBase.g4 @@ -414,6 +414,11 @@ fromClause : FROM relation (',' relation)* lateralView* pivotClause? ; +temporalClause + : FOR? (SYSTEM_TIME | TIMESTAMP) AS OF timestamp=valueExpression + | FOR? (SYSTEM_VERSION | VERSION) AS OF version=(INTEGER_VALUE | STRING) + ; + aggregation : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( WITH kind=ROLLUP @@ -510,7 +515,8 @@ identifierComment ; relationPrimary - : tableIdentifier sample? tableAlias #tableName + : tableIdentifier temporalClause? + sample? tableAlias #tableName | '(' queryNoWith ')' sample? tableAlias #aliasedQuery | '(' relation ')' sample? tableAlias #aliasedRelation | inlineTable #inlineTableDefault2 @@ -778,6 +784,7 @@ nonReserved | DATABASE | SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT | DIRECTORY | BOTH | LEADING | TRAILING + | SYSTEM_VERSION | VERSION | SYSTEM_TIME | TIMESTAMP ; SELECT: 'SELECT'; @@ -1015,6 +1022,11 @@ ANTI: 'ANTI'; LOCAL: 'LOCAL'; INPATH: 'INPATH'; +SYSTEM_VERSION: 'SYSTEM_VERSION'; +VERSION: 'VERSION'; +SYSTEM_TIME: 'SYSTEM_TIME'; +TIMESTAMP: 'TIMESTAMP'; + STRING : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' | '"' ( ~('"'|'\\') | ('\\' .) 
)* '"' diff --git a/hudi-spark-datasource/hudi-spark2/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 b/hudi-spark-datasource/hudi-spark2/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 index 6544d936243e9..a9fa73beef7bd 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 +++ b/hudi-spark-datasource/hudi-spark2/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 @@ -27,6 +27,32 @@ statement : mergeInto #mergeIntoTable | updateTableStmt #updateTable | deleteTableStmt #deleteTable + | query #queryStatement + | createTableHeader ('(' colTypeList ')')? tableProvider + ((OPTIONS options=tablePropertyList) | + (PARTITIONED BY partitionColumnNames=identifierList) | + bucketSpec | + locationSpec | + (COMMENT comment=STRING) | + (TBLPROPERTIES tableProps=tablePropertyList))* + (AS? query)? #createTable + | createTableHeader ('(' columns=colTypeList ')')? + ((COMMENT comment=STRING) | + (PARTITIONED BY '(' partitionColumns=colTypeList ')') | + bucketSpec | + skewSpec | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* + (AS? query)? #createHiveTable + | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? + VIEW (IF NOT EXISTS)? tableIdentifier + identifierCommentList? (COMMENT STRING)? + (PARTITIONED ON identifierList)? + (TBLPROPERTIES tablePropertyList)? AS query #createView + | ALTER VIEW tableIdentifier AS? query #alterViewQuery + | CACHE LAZY? TABLE tableIdentifier (AS? query)? #cacheTable | .*? #passThrough ; @@ -87,6 +113,7 @@ assignmentList assignment : key=qualifiedName EQ value=expression ; + qualifiedNameList : qualifiedName (',' qualifiedName)* ; diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeTravelRelation.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeTravelRelation.scala new file mode 100644 index 0000000000000..67f942881c39d --- /dev/null +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeTravelRelation.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} + +case class TimeTravelRelation( + table: LogicalPlan, + timestamp: Option[Expression], + version: Option[String]) extends Command { + override def children: Seq[LogicalPlan] = Seq.empty + + override def output: Seq[Attribute] = Nil + + override lazy val resolved: Boolean = false + + def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = this +} diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/HoodieSpark2ExtendedSqlAstBuilder.scala b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/HoodieSpark2ExtendedSqlAstBuilder.scala index bbc9014fe804f..66ac10f503b4a 100644 --- a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/HoodieSpark2ExtendedSqlAstBuilder.scala +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/HoodieSpark2ExtendedSqlAstBuilder.scala @@ -16,30 +16,56 @@ */ package org.apache.spark.sql.hudi.parser -import org.antlr.v4.runtime.tree.ParseTree -import org.apache.hudi.spark.sql.parser.HoodieSqlBaseBaseVisitor +import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} +import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser._ +import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseVisitor, HoodieSqlBaseParser} import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} +import org.apache.spark.sql.catalyst.catalog.CatalogStorageFormat import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface, ParserUtils} +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.hudi.parser.unsafe.types.CalendarInterval import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.util.random.RandomSampler +import java.sql.{Date, Timestamp} +import java.util.Locale +import javax.xml.bind.DatatypeConverter import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer /** * The AstBuilder for HoodieSqlParser to parser the AST tree to Logical Plan. * Here we only do the parser for the extended sql syntax. e.g MergeInto. For * other sql syntax we use the delegate sql parser which is the SparkSqlParser. */ -class HoodieSpark2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface) extends HoodieSqlBaseBaseVisitor[AnyRef] with Logging { +class HoodieSpark2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface) + extends HoodieSqlBaseBaseVisitor[AnyRef] with Logging { import ParserUtils._ - override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { - ctx.statement().accept(this).asInstanceOf[LogicalPlan] + protected def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } + + /** + * Override the default behavior for all visit methods. 
This will only return a non-null result + * when the context has only one child. This is done because there is no generic method to + * combine the results of the context children. In all other cases null is returned. + */ + override def visitChildren(node: RuleNode): AnyRef = { + if (node.getChildCount == 1) { + node.getChild(0).accept(this) + } else { + null + } } override def visitMergeIntoTable (ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) { @@ -141,6 +167,13 @@ class HoodieSpark2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface } } + /** + * Parse a qualified name to a multipart name. + */ + override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) { + ctx.identifier.asScala.map(_.getText).toSeq + } + override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) { val updateStmt = ctx.updateTableStmt() val table = UnresolvedRelation(visitTableIdentifier(updateStmt.tableIdentifier())) @@ -170,19 +203,9 @@ class HoodieSpark2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface } /** - * Parse the expression tree to spark sql Expression. - * Here we use the SparkSqlParser to do the parse. - */ - private def expression(tree: ParseTree): Expression = { - val expressionText = treeToString(tree) - delegate.parseExpression(expressionText) - } - - // ============== The following code is fork from org.apache.spark.sql.catalyst.parser.AstBuilder - /** - * If aliases specified in a FROM clause, create a subquery alias ([[SubqueryAlias]]) and - * column aliases for a [[LogicalPlan]]. - */ + * If aliases specified in a FROM clause, create a subquery alias ([[SubqueryAlias]]) and + * column aliases for a [[LogicalPlan]]. + */ protected def mayApplyAliasPlan(tableAlias: TableAliasContext, plan: LogicalPlan): LogicalPlan = { if (tableAlias.strictIdentifier != null) { val alias = tableAlias.strictIdentifier.getText @@ -196,35 +219,1798 @@ class HoodieSpark2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface plan } } + /** - * Parse a qualified name to a multipart name. - */ - override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) { - ctx.identifier.asScala.map(_.getText) + * Parse the expression tree to spark sql Expression. + * Here we use the SparkSqlParser to do the parse. + */ + private def expression(tree: ParseTree): Expression = { + val expressionText = treeToString(tree) + delegate.parseExpression(expressionText) } /** - * Create a Sequence of Strings for a parenthesis enclosed alias list. - */ - override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) { - visitIdentifierSeq(ctx.identifierSeq) + * Create an aliased table reference. This is typically used in FROM clauses. + */ + override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { + val tableId = visitTableIdentifier(ctx.tableIdentifier()) + val relation = UnresolvedRelation(tableId) + val table = mayApplyAliasPlan( + ctx.tableAlias, relation.optionalMap(ctx.temporalClause)(withTimeTravel)) + table.optionalMap(ctx.sample)(withSample) } - /** - * Create a Sequence of Strings for an identifier list. 
- */ - override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { - ctx.identifier.asScala.map(_.getText) + private def withTimeTravel( + ctx: TemporalClauseContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val v = ctx.version + val version = if (ctx.INTEGER_VALUE != null) { + Some(v.getText) + } else { + Option(v).map(string) + } + val timestamp = Option(ctx.timestamp).map(expression) + if (timestamp.exists(_.references.nonEmpty)) { + throw new ParseException( + "timestamp expression cannot refer to any columns", ctx.timestamp) + } + if (timestamp.exists(e => SubqueryExpression.hasSubquery(e))) { + throw new ParseException( + "timestamp expression cannot contain subqueries", ctx.timestamp) + } + TimeTravelRelation(plan, timestamp, version) + } + + // ============== The following code is fork from org.apache.spark.sql.catalyst.parser.AstBuilder + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { + visit(ctx.statement).asInstanceOf[LogicalPlan] + } + + override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) { + visitNamedExpression(ctx.namedExpression) + } + + override def visitSingleTableIdentifier( + ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) { + visitTableIdentifier(ctx.tableIdentifier) + } + + override def visitSingleFunctionIdentifier( + ctx: SingleFunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + visitFunctionIdentifier(ctx.functionIdentifier) + } + + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { + visitSparkDataType(ctx.dataType) + } + + override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = { + withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList))) } /* ******************************************************************************************** - * Table Identifier parsing + * Plan parsing * ******************************************************************************************** */ + protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree) + /** - * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern. - */ - override def visitTableIdentifier(ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) { - TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) + * Create a top-level plan with Common Table Expressions. + */ + override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) { + val query = plan(ctx.queryNoWith) + + // Apply CTEs + query.optional(ctx.ctes) { + val ctes = ctx.ctes.namedQuery.asScala.map { nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) + } + // Check for duplicate names. + checkDuplicateKeys(ctes, ctx) + With(query, ctes) + } + } + + /** + * Create a named logical plan. + * + * This is only used for Common Table Expressions. + */ + override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { + SubqueryAlias(ctx.name.getText, plan(ctx.query)) + } + + /** + * Create a logical plan which allows for multiple inserts using one 'from' statement. These + * queries have the following SQL form: + * {{{ + * [WITH cte...]? 
+ * FROM src + * [INSERT INTO tbl1 SELECT *]+ + * }}} + * For example: + * {{{ + * FROM db.tbl1 A + * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5 + * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12 + * }}} + * This (Hive) feature cannot be combined with set-operators. + */ + override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + + // Build the insert clauses. + val inserts = ctx.multiInsertQueryBody.asScala.map { + body => + validate(body.querySpecification.fromClause == null, + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", + body) + + withQuerySpecification(body.querySpecification, from). + // Add organization statements. + optionalMap(body.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(body.insertInto())(withInsertInto) + } + + // If there are multiple INSERTS just UNION them together into one query. + inserts match { + case Seq(query) => query + case queries => Union(queries) + } + } + + /** + * Create a logical plan for a regular (single-insert) query. + */ + override def visitSingleInsertQuery( + ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryTerm). + // Add organization statements. + optionalMap(ctx.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(ctx.insertInto())(withInsertInto) + } + + /** + * Parameters used for writing query to a table: + * (tableIdentifier, partitionKeys, exists). + */ + type InsertTableParams = (TableIdentifier, Map[String, Option[String]], Boolean) + + /** + * Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider). + */ + type InsertDirParams = (Boolean, CatalogStorageFormat, Option[String]) + + /** + * Add an + * {{{ + * INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? + * INSERT INTO [TABLE] tableIdentifier [partitionSpec] + * INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat] + * INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS tablePropertyList] + * }}} + * operation to logical plan + */ + private def withInsertInto( + ctx: InsertIntoContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + ctx match { + case table: InsertIntoTableContext => + val (tableIdent, partitionKeys, exists) = visitInsertIntoTable(table) + InsertIntoTable(UnresolvedRelation(tableIdent), partitionKeys, query, false, exists) + case table: InsertOverwriteTableContext => + val (tableIdent, partitionKeys, exists) = visitInsertOverwriteTable(table) + InsertIntoTable(UnresolvedRelation(tableIdent), partitionKeys, query, true, exists) + case dir: InsertOverwriteDirContext => + val (isLocal, storage, provider) = visitInsertOverwriteDir(dir) + InsertIntoDir(isLocal, storage, provider, query, overwrite = true) + case hiveDir: InsertOverwriteHiveDirContext => + val (isLocal, storage, provider) = visitInsertOverwriteHiveDir(hiveDir) + InsertIntoDir(isLocal, storage, provider, query, overwrite = true) + case _ => + throw new ParseException("Invalid InsertIntoContext", ctx) + } + } + + /** + * Add an INSERT INTO TABLE operation to the logical plan. 
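 * For example (hypothetical table and partition names), `INSERT INTO TABLE t PARTITION (dt = '2021-01-01')`
 * yields `(TableIdentifier("t"), Map("dt" -> Some("2021-01-01")), false)`.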
+ */ + override def visitInsertIntoTable( + ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) { + val tableIdent = visitTableIdentifier(ctx.tableIdentifier) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + (tableIdent, partitionKeys, false) + } + + /** + * Add an INSERT OVERWRITE TABLE operation to the logical plan. + */ + override def visitInsertOverwriteTable( + ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) { + assert(ctx.OVERWRITE() != null) + val tableIdent = visitTableIdentifier(ctx.tableIdentifier) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty) + if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) { + throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. Specified " + + "partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx) + } + + (tableIdent, partitionKeys, ctx.EXISTS() != null) + } + + /** + * Write to a directory, returning a [[InsertIntoDir]] logical plan. + */ + override def visitInsertOverwriteDir( + ctx: InsertOverwriteDirContext): InsertDirParams = withOrigin(ctx) { + throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx) + } + + /** + * Write to a directory, returning a [[InsertIntoDir]] logical plan. + */ + override def visitInsertOverwriteHiveDir( + ctx: InsertOverwriteHiveDirContext): InsertDirParams = withOrigin(ctx) { + throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx) + } + + /** + * Create a partition specification map. + */ + override def visitPartitionSpec( + ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { + val parts = ctx.partitionVal.asScala.map { pVal => + val name = pVal.identifier.getText + val value = Option(pVal.constant).map(visitStringConstant) + name -> value + } + // Before calling `toMap`, we check duplicated keys to avoid silently ignore partition values + // in partition spec like PARTITION(a='1', b='2', a='3'). The real semantical check for + // partition columns will be done in analyzer. + checkDuplicateKeys(parts, ctx) + parts.toMap + } + + /** + * Create a partition specification map without optional values. + */ + protected def visitNonOptionalPartitionSpec( + ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { + visitPartitionSpec(ctx).map { + case (key, None) => throw new ParseException(s"Found an empty partition key '$key'.", ctx) + case (key, Some(value)) => key -> value + } + } + + /** + * Convert a constant of any type into a string. This is typically used in DDL commands, and its + * main purpose is to prevent slight differences due to back to back conversions i.e.: + * String -> Literal -> String. + */ + protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) { + ctx match { + case s: StringLiteralContext => createString(s) + case o => o.getText + } + } + + /** + * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These + * clauses determine the shape (ordering/partitioning/rows) of the query result. + */ + private def withQueryResultClauses( + ctx: QueryOrganizationContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. 
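// For example, `... ORDER BY a` becomes a global Sort (global = true) over the query, while `... SORT BY a`
// becomes a per-partition Sort (global = false). DISTRIBUTE BY and CLUSTER BY fall through to
// withRepartitionByExpression, which this fork rejects with a ParseException.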
+ val withOrder = if ( + !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // ORDER BY ... + Sort(order.asScala.map(visitSortItem), global = true, query) + } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... + Sort(sort.asScala.map(visitSortItem), global = false, query) + } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // DISTRIBUTE BY ... + withRepartitionByExpression(ctx, expressionList(distributeBy), query) + } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... DISTRIBUTE BY ... + Sort( + sort.asScala.map(visitSortItem), + global = false, + withRepartitionByExpression(ctx, expressionList(distributeBy), query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { + // CLUSTER BY ... + val expressions = expressionList(clusterBy) + Sort( + expressions.map(SortOrder(_, Ascending)), + global = false, + withRepartitionByExpression(ctx, expressions, query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // [EMPTY] + query + } else { + throw new ParseException( + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx) + } + + // WINDOWS + val withWindow = withOrder.optionalMap(windows)(withWindows) + + // LIMIT + // - LIMIT ALL is the same as omitting the LIMIT clause + withWindow.optional(limit) { + Limit(typedVisit(limit), withWindow) + } + } + + /** + * Create a clause for DISTRIBUTE BY. + */ + protected def withRepartitionByExpression( + ctx: QueryOrganizationContext, + expressions: Seq[Expression], + query: LogicalPlan): LogicalPlan = { + throw new ParseException("DISTRIBUTE BY is not supported", ctx) + } + + /** + * Create a logical plan using a query specification. + */ + override def visitQuerySpecification( + ctx: QuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation().optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withQuerySpecification(ctx, from) + } + + /** + * Add a query specification to a logical plan. The query specification is the core of the logical + * plan, this is where sourcing (FROM clause), transforming (SELECT TRANSFORM/MAP/REDUCE), + * projection (SELECT), aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place. + * + * Note that query hints are ignored (both by the parser and the builder). + */ + private def withQuerySpecification( + ctx: QuerySpecificationContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // WHERE + def filter(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = { + Filter(expression(ctx), plan) + } + + def withHaving(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = { + // Note that we add a cast to non-predicate expressions. If the expression itself is + // already boolean, the optimizer will get rid of the unnecessary cast. + val predicate = expression(ctx) match { + case p: Predicate => p + case e => Cast(e, BooleanType) + } + Filter(predicate, plan) + } + + + // Expressions. + val expressions = Option(namedExpressionSeq).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + + // Create either a transform or a regular query. 
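// specType distinguishes Hive-style script transforms (MAP / REDUCE / TRANSFORM) from a regular SELECT;
// note that the transform branch still calls withScriptIOSchema, which this fork rejects with a ParseException.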
+ val specType = Option(kind).map(_.getType).getOrElse(HoodieSqlBaseParser.SELECT) + specType match { + case HoodieSqlBaseParser.MAP | HoodieSqlBaseParser.REDUCE | HoodieSqlBaseParser.TRANSFORM => + // Transform + + // Add where. + val withFilter = relation.optionalMap(where)(filter) + + // Create the attributes. + val (attributes, schemaLess) = if (colTypeList != null) { + // Typed return columns. + (createSchema(colTypeList).toAttributes, false) + } else if (identifierSeq != null) { + // Untyped return columns. + val attrs = visitIdentifierSeq(identifierSeq).map { name => + AttributeReference(name, StringType, nullable = true)() + } + (attrs, false) + } else { + (Seq(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) + } + + // Create the transform. + ScriptTransformation( + expressions, + string(script), + attributes, + withFilter, + withScriptIOSchema( + ctx, inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) + + case HoodieSqlBaseParser.SELECT => + // Regular select + + // Add lateral views. + val withLateralView = ctx.lateralView.asScala.foldLeft(relation)(withGenerate) + + // Add where. + val withFilter = withLateralView.optionalMap(where)(filter) + + // Add aggregation or a project. + val namedExpressions = expressions.map { + case e: NamedExpression => e + case e: Expression => UnresolvedAlias(e) + } + + def createProject() = if (namedExpressions.nonEmpty) { + Project(namedExpressions, withFilter) + } else { + withFilter + } + + val withProject = if (aggregation == null && having != null) { + if (conf.getConf(SQLConf.LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE)) { + // If the legacy conf is set, treat HAVING without GROUP BY as WHERE. + withHaving(having, createProject()) + } else { + // According to SQL standard, HAVING without GROUP BY means global aggregate. + withHaving(having, Aggregate(Nil, namedExpressions, withFilter)) + } + } else if (aggregation != null) { + val aggregate = withAggregation(aggregation, namedExpressions, withFilter) + aggregate.optionalMap(having)(withHaving) + } else { + // When hitting this branch, `having` must be null. + createProject() + } + + // Distinct + val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) { + Distinct(withProject) + } else { + withProject + } + + // Window + val withWindow = withDistinct.optionalMap(windows)(withWindows) + + // Hint + hints.asScala.foldRight(withWindow)(withHints) + } + } + + /** + * Create a (Hive based) [[ScriptInputOutputSchema]]. + */ + protected def withScriptIOSchema( + ctx: QuerySpecificationContext, + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): ScriptInputOutputSchema = { + throw new ParseException("Script Transform is not supported", ctx) + } + + /** + * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma + * separated) relations here, these get converted into a single plan by condition-less inner join. 
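 * For example (hypothetical relation names), `FROM a, b, c` folds left into
 * `Join(Join(a, b, Inner, None), c, Inner, None)`.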
+ */ + override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { + val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) => + val right = plan(relation.relationPrimary) + val join = right.optionalMap(left)(Join(_, _, Inner, None)) + withJoinRelations(join, relation) + } + if (ctx.pivotClause() != null) { + if (!ctx.lateralView.isEmpty) { + throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx) + } + withPivot(ctx.pivotClause, from) + } else { + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } + } + + /** + * Connect two queries by a Set operator. + * + * Supported Set operators are: + * - UNION [ DISTINCT | ALL ] + * - EXCEPT [ DISTINCT | ALL ] + * - MINUS [ DISTINCT | ALL ] + * - INTERSECT [DISTINCT | ALL] + */ + override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { + val left = plan(ctx.left) + val right = plan(ctx.right) + val all = Option(ctx.setQuantifier()).exists(_.ALL != null) + ctx.operator.getType match { + case HoodieSqlBaseParser.UNION if all => + Union(left, right) + case HoodieSqlBaseParser.UNION => + Distinct(Union(left, right)) + case HoodieSqlBaseParser.INTERSECT if all => + Intersect(left, right, isAll = true) + case HoodieSqlBaseParser.INTERSECT => + Intersect(left, right, isAll = false) + case HoodieSqlBaseParser.EXCEPT if all => + Except(left, right, isAll = true) + case HoodieSqlBaseParser.EXCEPT => + Except(left, right, isAll = false) + case HoodieSqlBaseParser.SETMINUS if all => + Except(left, right, isAll = true) + case HoodieSqlBaseParser.SETMINUS => + Except(left, right, isAll = false) + } + } + + /** + * Add a [[WithWindowDefinition]] operator to a logical plan. + */ + private def withWindows( + ctx: WindowsContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Collect all window specifications defined in the WINDOW clause. + val baseWindowMap = ctx.namedWindow.asScala.map { + wCtx => + (wCtx.identifier.getText, typedVisit[WindowSpec](wCtx.windowSpec)) + }.toMap + + // Handle cases like + // window w1 as (partition by p_mfgr order by p_name + // range between 2 preceding and 2 following), + // w2 as w1 + val windowMapView = baseWindowMap.mapValues { + case WindowSpecReference(name) => + baseWindowMap.get(name) match { + case Some(spec: WindowSpecDefinition) => + spec + case Some(ref) => + throw new ParseException(s"Window reference '$name' is not a window specification", ctx) + case None => + throw new ParseException(s"Cannot resolve window reference '$name'", ctx) + } + case spec: WindowSpecDefinition => spec + } + + // Note that mapValues creates a view instead of materialized map. We force materialization by + // mapping over identity. + WithWindowDefinition(windowMapView.map(identity), query) + } + + /** + * Add an [[Aggregate]] or [[GroupingSets]] to a logical plan. + */ + private def withAggregation( + ctx: AggregationContext, + selectExpressions: Seq[NamedExpression], + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val groupByExpressions = expressionList(ctx.groupingExpressions) + + if (ctx.GROUPING != null) { + // GROUP BY .... GROUPING SETS (...) + val selectedGroupByExprs = + ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e))) + GroupingSets(selectedGroupByExprs, groupByExpressions, query, selectExpressions) + } else { + // GROUP BY .... (WITH CUBE | WITH ROLLUP)? 
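// For example, `GROUP BY a, b WITH ROLLUP` wraps the grouping expressions as Seq(Rollup(Seq(a, b)))
// before building the Aggregate, while a plain `GROUP BY a, b` keeps them as-is.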
+ val mappedGroupByExpressions = if (ctx.CUBE != null) { + Seq(Cube(groupByExpressions)) + } else if (ctx.ROLLUP != null) { + Seq(Rollup(groupByExpressions)) + } else { + groupByExpressions + } + Aggregate(mappedGroupByExpressions, selectExpressions, query) + } + } + + /** + * Add [[UnresolvedHint]]s to a logical plan. + */ + private def withHints( + ctx: HintContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + var plan = query + ctx.hintStatements.asScala.reverse.foreach { case stmt => + plan = UnresolvedHint(stmt.hintName.getText, stmt.parameters.asScala.map(expression), plan) + } + plan + } + + /** + * Add a [[Pivot]] to a logical plan. + */ + private def withPivot( + ctx: PivotClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val aggregates = Option(ctx.aggregates).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) { + UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText) + } else { + CreateStruct( + ctx.pivotColumn.identifiers.asScala.map( + identifier => UnresolvedAttribute.quoted(identifier.getText))) + } + val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue) + Pivot(None, pivotColumn, pivotValues, aggregates, query) + } + + /** + * Create a Pivot column value with or without an alias. + */ + override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else { + e + } + } + + /** + * Add a [[Generate]] (Lateral View) to a logical plan. + */ + private def withGenerate( + query: LogicalPlan, + ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) { + val expressions = expressionList(ctx.expression) + Generate( + UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions), + unrequiredChildIndex = Nil, + outer = ctx.OUTER != null, + Some(ctx.tblName.getText.toLowerCase), + ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply), + query) + } + + /** + * Create a single relation referenced in a FROM clause. This method is used when a part of the + * join condition is nested, for example: + * {{{ + * select * from t1 join (t2 cross join t3) on col1 = col2 + * }}} + */ + override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) { + withJoinRelations(plan(ctx.relationPrimary), ctx) + } + + /** + * Join one more [[LogicalPlan]]s to the current logical plan. 
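 * For example (hypothetical tables), `t1 JOIN t2 ON t1.id = t2.id` becomes
 * `Join(t1, t2, Inner, Some(t1.id = t2.id))`, and `t1 LEFT JOIN t2 USING (id)` becomes
 * `Join(t1, t2, UsingJoin(LeftOuter, Seq("id")), None)`.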
+ */ + private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = { + ctx.joinRelation.asScala.foldLeft(base) { (left, join) => + withOrigin(join) { + val baseJoinType = join.joinType match { + case null => Inner + case jt if jt.CROSS != null => Cross + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.ANTI != null => LeftAnti + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } + + // Resolve the join type and join condition + val (joinType, condition) = Option(join.joinCriteria) match { + case Some(c) if c.USING != null => + (UsingJoin(baseJoinType, c.identifier.asScala.map(_.getText)), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case None if join.NATURAL != null => + if (baseJoinType == Cross) { + throw new ParseException("NATURAL CROSS JOIN is not supported", ctx) + } + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + Join(left, plan(join.right), joinType, condition) + } + } + } + + /** + * Add a [[Sample]] to a logical plan. + * + * This currently supports the following sampling methods: + * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. + * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages + * are defined as a number between 0 and 100. + * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction. + */ + private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Create a sampled plan if we need one. + def sample(fraction: Double): Sample = { + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. + val eps = RandomSampler.roundingEpsilon + validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps, + s"Sampling fraction ($fraction) must be on interval [0, 1]", + ctx) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query) + } + + if (ctx.sampleMethod() == null) { + throw new ParseException("TABLESAMPLE does not accept empty inputs.", ctx) + } + + ctx.sampleMethod() match { + case ctx: SampleByRowsContext => + Limit(expression(ctx.expression), query) + + case ctx: SampleByPercentileContext => + val fraction = ctx.percentage.getText.toDouble + val sign = if (ctx.negativeSign == null) 1 else -1 + sample(sign * fraction / 100.0d) + + case ctx: SampleByBytesContext => + val bytesStr = ctx.bytes.getText + if (bytesStr.matches("[0-9]+[bBkKmMgG]")) { + throw new ParseException("TABLESAMPLE(byteLengthLiteral) is not supported", ctx) + } else { + throw new ParseException( + bytesStr + " is not a valid byte length literal, " + + "expected syntax: DIGIT+ ('B' | 'K' | 'M' | 'G')", ctx) + } + + case ctx: SampleByBucketContext if ctx.ON() != null => + if (ctx.identifier != null) { + throw new ParseException( + "TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported", ctx) + } else { + throw new ParseException( + "TABLESAMPLE(BUCKET x OUT OF y ON function) is not supported", ctx) + } + + case ctx: SampleByBucketContext => + sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) + } + } + + /** + * Create a logical plan for a sub-query. 
+ */ + override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryNoWith) + } + + /** + * Create an un-aliased table reference. This is typically used for top-level table references, + * for example: + * {{{ + * INSERT INTO db.tbl2 + * TABLE db.tbl1 + * }}} + */ + override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { + UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier)) + } + + /** + * Create a table-valued function call with arguments, e.g. range(1000) + */ + override def visitTableValuedFunction(ctx: TableValuedFunctionContext) + : LogicalPlan = withOrigin(ctx) { + val func = ctx.functionTable + val aliases = if (func.tableAlias.identifierList != null) { + visitIdentifierList(func.tableAlias.identifierList) + } else { + Seq.empty + } + + val tvf = UnresolvedTableValuedFunction( + func.identifier.getText, func.expression.asScala.map(expression), aliases) + tvf.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan) + } + + /** + * Create an inline table (a virtual table in Hive parlance). + */ + override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { + // Get the backing expressions. + val rows = ctx.expression.asScala.map { e => + expression(e) match { + // inline table comes in two styles: + // style 1: values (1), (2), (3) -- multiple columns are supported + // style 2: values 1, 2, 3 -- only a single column is supported here + case struct: CreateNamedStruct => struct.valExprs // style 1 + case child => Seq(child) // style 2 + } + } + + val aliases = if (ctx.tableAlias.identifierList != null) { + visitIdentifierList(ctx.tableAlias.identifierList) + } else { + Seq.tabulate(rows.head.size)(i => s"col${i + 1}") + } + + val table = UnresolvedInlineTable(aliases, rows) + table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a join relation. This is practically the same as + * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. We could add alias names for output columns, for example: + * {{{ + * SELECT a, b, c, d FROM (src1 s1 INNER JOIN src2 s2 ON s1.id = s2.id) dst(a, b, c, d) + * }}} + */ + override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) { + val relation = plan(ctx.relation).optionalMap(ctx.sample)(withSample) + mayApplyAliasPlan(ctx.tableAlias, relation) + } + + /** + * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as + * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. We could add alias names for output columns, for example: + * {{{ + * SELECT col1, col2 FROM testData AS t(col1, col2) + * }}} + */ + override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { + val relation = plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample) + if (ctx.tableAlias.strictIdentifier == null) { + // For un-aliased subqueries, use a default alias name that is not likely to conflict with + // normal subquery names, so that parent operators can only access the columns in subquery by + // unqualified names. Users can still use this special qualifier to access columns if they + // know it, but that's not recommended. + SubqueryAlias("__auto_generated_subquery_name", relation) + } else { + mayApplyAliasPlan(ctx.tableAlias, relation) + } + } + + /** + * Create an alias ([[SubqueryAlias]]) for a [[LogicalPlan]]. 
+ */ + private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = { + SubqueryAlias(alias.getText, plan) + } + + /** + * Create a Sequence of Strings for a parenthesis enclosed alias list. + */ + override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) { + visitIdentifierSeq(ctx.identifierSeq) + } + + /** + * Create a Sequence of Strings for an identifier list. + */ + override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { + ctx.identifier.asScala.map(_.getText) + } + + /* ******************************************************************************************** + * Table Identifier parsing + * ******************************************************************************************** */ + /** + * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern. + */ + override def visitTableIdentifier( + ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) { + TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) + } + + /** + * Create a [[FunctionIdentifier]] from a 'functionName' or 'databaseName'.'functionName' pattern. + */ + override def visitFunctionIdentifier( + ctx: FunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText)) + } + + /* ******************************************************************************************** + * Expression parsing + * ******************************************************************************************** */ + /** + * Create sequence of expressions from the given sequence of contexts. + */ + private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = { + trees.asScala.map(expression) + } + + /** + * Create a star (i.e. all) expression; this selects all elements (in the specified object). + * Both un-targeted (global) and targeted aliases are supported. + */ + override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) { + UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText))) + } + + /** + * Create an aliased expression if an alias is specified. Both single and multi-aliases are + * supported. + */ + override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else if (ctx.identifierList != null) { + MultiAlias(e, visitIdentifierList(ctx.identifierList)) + } else { + e + } + } + + /** + * Combine a number of boolean expressions into a balanced expression tree. These expressions are + * either combined by a logical [[And]] or a logical [[Or]]. + * + * A balanced binary tree is created because regular left recursive trees cause considerable + * performance degradations and can cause stack overflows. + */ + override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) { + val expressionType = ctx.operator.getType + val expressionCombiner = expressionType match { + case HoodieSqlBaseParser.AND => And.apply _ + case HoodieSqlBaseParser.OR => Or.apply _ + } + + // Collect all similar left hand contexts. 
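+ // For a query like `a AND b AND c AND d` the parse tree leans to the left, so this loop
+ // walks it and gathers every operand context; reduceToExpressionTree below then recombines
+ // the operands into a balanced tree.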
+ val contexts = ArrayBuffer(ctx.right) + var current = ctx.left + def collectContexts: Boolean = current match { + case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType => + contexts += lbc.right + current = lbc.left + true + case _ => + contexts += current + false + } + while (collectContexts) { + // No body - all updates take place in the collectContexts. + } + + // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them + // into expressions. + val expressions = contexts.reverseMap(expression) + + // Create a balanced tree. + def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match { + case 0 => + expressions(low) + case 1 => + expressionCombiner(expressions(low), expressions(high)) + case x => + val mid = low + x / 2 + expressionCombiner( + reduceToExpressionTree(low, mid), + reduceToExpressionTree(mid + 1, high)) + } + reduceToExpressionTree(0, expressions.size - 1) + } + + /** + * Invert a boolean expression. + */ + override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) { + Not(expression(ctx.booleanExpression())) + } + + /** + * Create a filtering correlated sub-query (EXISTS). + */ + override def visitExists(ctx: ExistsContext): Expression = { + Exists(plan(ctx.query)) + } + + /** + * Create a comparison expression. This compares two expressions. The following comparison + * operators are supported: + * - Equal: '=' or '==' + * - Null-safe Equal: '<=>' + * - Not Equal: '<>' or '!=' + * - Less than: '<' + * - Less then or Equal: '<=' + * - Greater than: '>' + * - Greater then or Equal: '>=' + */ + override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode] + operator.getSymbol.getType match { + case HoodieSqlBaseParser.EQ => + EqualTo(left, right) + case HoodieSqlBaseParser.NSEQ => + EqualNullSafe(left, right) + case HoodieSqlBaseParser.NEQ | HoodieSqlBaseParser.NEQJ => + Not(EqualTo(left, right)) + case HoodieSqlBaseParser.LT => + LessThan(left, right) + case HoodieSqlBaseParser.LTE => + LessThanOrEqual(left, right) + case HoodieSqlBaseParser.GT => + GreaterThan(left, right) + case HoodieSqlBaseParser.GTE => + GreaterThanOrEqual(left, right) + } + } + + /** + * Create a predicated expression. A predicated expression is a normal expression with a + * predicate attached to it, for example: + * {{{ + * a + 1 IS NULL + * }}} + */ + override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) { + val e = expression(ctx.valueExpression) + if (ctx.predicate != null) { + withPredicate(e, ctx.predicate) + } else { + e + } + } + + /** + * Add a predicate to the given expression. Supported expressions are: + * - (NOT) BETWEEN + * - (NOT) IN + * - (NOT) LIKE + * - (NOT) RLIKE + * - IS (NOT) NULL. + * - IS (NOT) DISTINCT FROM + */ + private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) { + // Invert a predicate if it has a valid NOT clause. + def invertIfNotDefined(e: Expression): Expression = ctx.NOT match { + case null => e + case not => Not(e) + } + + def getValueExpressions(e: Expression): Seq[Expression] = e match { + case c: CreateNamedStruct => c.valExprs + case other => Seq(other) + } + + // Create the predicate. 
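+ // For example:
+ //   a BETWEEN 1 AND 10        =>  a >= 1 AND a <= 10
+ //   a NOT IN (1, 2)           =>  NOT(a IN (1, 2))   (via invertIfNotDefined)
+ //   a IS NOT DISTINCT FROM b  =>  a <=> b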
+ ctx.kind.getType match { + case HoodieSqlBaseParser.BETWEEN => + // BETWEEN is translated to lower <= e && e <= upper + invertIfNotDefined(And( + GreaterThanOrEqual(e, expression(ctx.lower)), + LessThanOrEqual(e, expression(ctx.upper)))) + case HoodieSqlBaseParser.IN if ctx.query != null => + invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query)))) + case HoodieSqlBaseParser.IN => + invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) + case HoodieSqlBaseParser.LIKE => + invertIfNotDefined(Like(e, expression(ctx.pattern))) + case HoodieSqlBaseParser.RLIKE => + invertIfNotDefined(RLike(e, expression(ctx.pattern))) + case HoodieSqlBaseParser.NULL if ctx.NOT != null => + IsNotNull(e) + case HoodieSqlBaseParser.NULL => + IsNull(e) + case HoodieSqlBaseParser.DISTINCT if ctx.NOT != null => + EqualNullSafe(e, expression(ctx.right)) + case HoodieSqlBaseParser.DISTINCT => + Not(EqualNullSafe(e, expression(ctx.right))) + } + } + + /** + * Create a binary arithmetic expression. The following arithmetic operators are supported: + * - Multiplication: '*' + * - Division: '/' + * - Hive Long Division: 'DIV' + * - Modulo: '%' + * - Addition: '+' + * - Subtraction: '-' + * - Binary AND: '&' + * - Binary XOR + * - Binary OR: '|' + */ + override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + ctx.operator.getType match { + case HoodieSqlBaseParser.ASTERISK => + Multiply(left, right) + case HoodieSqlBaseParser.SLASH => + Divide(left, right) + case HoodieSqlBaseParser.PERCENT => + Remainder(left, right) + case HoodieSqlBaseParser.DIV => + Cast(Divide(left, right), LongType) + case HoodieSqlBaseParser.PLUS => + Add(left, right) + case HoodieSqlBaseParser.MINUS => + Subtract(left, right) + case HoodieSqlBaseParser.CONCAT_PIPE => + Concat(left :: right :: Nil) + case HoodieSqlBaseParser.AMPERSAND => + BitwiseAnd(left, right) + case HoodieSqlBaseParser.HAT => + BitwiseXor(left, right) + case HoodieSqlBaseParser.PIPE => + BitwiseOr(left, right) + } + } + + /** + * Create a unary arithmetic expression. The following arithmetic operators are supported: + * - Plus: '+' + * - Minus: '-' + * - Bitwise Not: '~' + */ + override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) { + val value = expression(ctx.valueExpression) + ctx.operator.getType match { + case HoodieSqlBaseParser.PLUS => + value + case HoodieSqlBaseParser.MINUS => + UnaryMinus(value) + case HoodieSqlBaseParser.TILDE => + BitwiseNot(value) + } + } + + /** + * Create a [[Cast]] expression. + */ + override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) { + Cast(expression(ctx.expression), visitSparkDataType(ctx.dataType)) + } + + /** + * Create a [[CreateStruct]] expression. + */ + override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) { + CreateStruct(ctx.argument.asScala.map(expression)) + } + + /** + * Create a [[First]] expression. + */ + override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + First(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression() + } + + /** + * Create a [[Last]] expression. + */ + override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + Last(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression() + } + + /** + * Create a Position expression. 
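+ * This maps `POSITION(substr IN str)` to a [[StringLocate]] expression; for example
+ * `POSITION('b' IN 'abc')` evaluates to 2 (positions are 1-based).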
+ */ + override def visitPosition(ctx: PositionContext): Expression = withOrigin(ctx) { + new StringLocate(expression(ctx.substr), expression(ctx.str)) + } + + /** + * Create a Extract expression. + */ + override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) { + ctx.field.getText.toUpperCase(Locale.ROOT) match { + case "YEAR" => + Year(expression(ctx.source)) + case "QUARTER" => + Quarter(expression(ctx.source)) + case "MONTH" => + Month(expression(ctx.source)) + case "WEEK" => + WeekOfYear(expression(ctx.source)) + case "DAY" => + DayOfMonth(expression(ctx.source)) + case "DAYOFWEEK" => + DayOfWeek(expression(ctx.source)) + case "HOUR" => + Hour(expression(ctx.source)) + case "MINUTE" => + Minute(expression(ctx.source)) + case "SECOND" => + Second(expression(ctx.source)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) + } + } + + /** + * Create a (windowed) Function expression. + */ + override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { + def replaceFunctions( + funcID: FunctionIdentifier, + ctx: FunctionCallContext): FunctionIdentifier = { + val opt = ctx.trimOption + if (opt != null) { + if (ctx.qualifiedName.getText.toLowerCase(Locale.ROOT) != "trim") { + throw new ParseException(s"The specified function ${ctx.qualifiedName.getText} " + + s"doesn't support with option ${opt.getText}.", ctx) + } + opt.getType match { + case HoodieSqlBaseParser.BOTH => funcID + case HoodieSqlBaseParser.LEADING => funcID.copy(funcName = "ltrim") + case HoodieSqlBaseParser.TRAILING => funcID.copy(funcName = "rtrim") + case _ => throw new ParseException("Function trim doesn't support with " + + s"type ${opt.getType}. Please use BOTH, LEADING or Trailing as trim type", ctx) + } + } else { + funcID + } + } + // Create the function call. + val name = ctx.qualifiedName.getText + val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) + val arguments = ctx.argument.asScala.map(expression) match { + case Seq(UnresolvedStar(None)) + if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct => + // Transform COUNT(*) into COUNT(1). + Seq(Literal(1)) + case expressions => + expressions + } + val funcId = replaceFunctions(visitFunctionName(ctx.qualifiedName), ctx) + val function = UnresolvedFunction(funcId, arguments, isDistinct) + + + // Check if the function is evaluated in a windowed context. + ctx.windowSpec match { + case spec: WindowRefContext => + UnresolvedWindowExpression(function, visitWindowRef(spec)) + case spec: WindowDefContext => + WindowExpression(function, visitWindowDef(spec)) + case _ => function + } + } + + /** + * Create a function database (optional) and name pair. + */ + protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = { + ctx.identifier().asScala.map(_.getText) match { + case Seq(db, fn) => FunctionIdentifier(fn, Option(db)) + case Seq(fn) => FunctionIdentifier(fn, None) + case other => throw new ParseException(s"Unsupported function name '${ctx.getText}'", ctx) + } + } + + /** + * Create an [[LambdaFunction]]. 
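+ * For example, the `x -> x + 1` in `transform(array(1, 2, 3), x -> x + 1)`.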
+ */ + override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) { + val arguments = ctx.IDENTIFIER().asScala.map { name => + UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts) + } + val function = expression(ctx.expression).transformUp { + case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts) + } + LambdaFunction(function, arguments) + } + + /** + * Create a reference to a window frame, i.e. [[WindowSpecReference]]. + */ + override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) { + WindowSpecReference(ctx.identifier.getText) + } + + /** + * Create a window definition, i.e. [[WindowSpecDefinition]]. + */ + override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) { + // CLUSTER BY ... | PARTITION BY ... ORDER BY ... + val partition = ctx.partition.asScala.map(expression) + val order = ctx.sortItem.asScala.map(visitSortItem) + + // RANGE/ROWS BETWEEN ... + val frameSpecOption = Option(ctx.windowFrame).map { frame => + val frameType = frame.frameType.getType match { + case HoodieSqlBaseParser.RANGE => RangeFrame + case HoodieSqlBaseParser.ROWS => RowFrame + } + + SpecifiedWindowFrame( + frameType, + visitFrameBound(frame.start), + Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow)) + } + + WindowSpecDefinition( + partition, + order, + frameSpecOption.getOrElse(UnspecifiedFrame)) + } + + /** + * Create or resolve a frame boundary expressions. + */ + override def visitFrameBound(ctx: FrameBoundContext): Expression = withOrigin(ctx) { + def value: Expression = { + val e = expression(ctx.expression) + validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx) + e + } + + ctx.boundType.getType match { + case HoodieSqlBaseParser.PRECEDING if ctx.UNBOUNDED != null => + UnboundedPreceding + case HoodieSqlBaseParser.PRECEDING => + UnaryMinus(value) + case HoodieSqlBaseParser.CURRENT => + CurrentRow + case HoodieSqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null => + UnboundedFollowing + case HoodieSqlBaseParser.FOLLOWING => + value + } + } + + /** + * Create a [[CreateStruct]] expression. + */ + override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { + CreateStruct(ctx.namedExpression().asScala.map(expression)) + } + + /** + * Create a [[ScalarSubquery]] expression. + */ + override def visitSubqueryExpression( + ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) { + ScalarSubquery(plan(ctx.query)) + } + + /** + * Create a value based [[CaseWhen]] expression. This has the following SQL form: + * {{{ + * CASE [expression] + * WHEN [value] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + */ + override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { + val e = expression(ctx.value) + val branches = ctx.whenClause.asScala.map { wCtx => + (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) + } + CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax: + * {{{ + * CASE + * WHEN [predicate] THEN [expression] + * ... 
+ * ELSE [expression] + * END + * }}} + * + * @param ctx the parse tree + * */ + override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) { + val branches = ctx.whenClause.asScala.map { wCtx => + (expression(wCtx.condition), expression(wCtx.result)) + } + CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + } + + /** + * Currently only regex in expressions of SELECT statements are supported; in other + * places, e.g., where `(a)?+.+` = 2, regex are not meaningful. + */ + private def canApplyRegex(ctx: ParserRuleContext): Boolean = withOrigin(ctx) { + var parent = ctx.getParent + var result = false + while (parent != null) { + if (parent.isInstanceOf[NamedExpressionContext]) { + result = true + } + parent = parent.getParent + } + result + } + + /** + * Create a dereference expression. The return type depends on the type of the parent. + * If the parent is an [[UnresolvedAttribute]], it can be a [[UnresolvedAttribute]] or + * a [[UnresolvedRegex]] for regex quoted in ``; if the parent is some other expression, + * it can be [[UnresolvedExtractValue]]. + */ + override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) { + val attr = ctx.fieldName.getText + expression(ctx.base) match { + case unresolved_attr @ UnresolvedAttribute(nameParts) => + ctx.fieldName.getStart.getText match { + case escapedIdentifier(columnNameRegex) + if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) => + UnresolvedRegex(columnNameRegex, Some(unresolved_attr.name), + conf.caseSensitiveAnalysis) + case _ => + UnresolvedAttribute(nameParts :+ attr) + } + case e => + UnresolvedExtractValue(e, Literal(attr)) + } + } + + /** + * Create an [[UnresolvedAttribute]] expression or a [[UnresolvedRegex]] if it is a regex + * quoted in `` + */ + override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) { + ctx.getStart.getText match { + case escapedIdentifier(columnNameRegex) + if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) => + UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis) + case _ => + UnresolvedAttribute.quoted(ctx.getText) + } + + } + + /** + * Create an [[UnresolvedExtractValue]] expression, this is used for subscript access to an array. + */ + override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) { + UnresolvedExtractValue(expression(ctx.value), expression(ctx.index)) + } + + /** + * Create an expression for an expression between parentheses. This is need because the ANTLR + * visitor cannot automatically convert the nested context into an expression. + */ + override def visitParenthesizedExpression( + ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) { + expression(ctx.expression) + } + + /** + * Create a [[SortOrder]] expression. + */ + override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) { + val direction = if (ctx.DESC != null) { + Descending + } else { + Ascending + } + val nullOrdering = if (ctx.FIRST != null) { + NullsFirst + } else if (ctx.LAST != null) { + NullsLast + } else { + direction.defaultNullOrdering + } + SortOrder(expression(ctx.expression), direction, nullOrdering, Set.empty) + } + + /** + * Create a typed Literal expression. A typed literal has the following SQL syntax: + * {{{ + * [TYPE] '[VALUE]' + * }}} + * Currently Date, Timestamp and Binary typed literals are supported. 
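+ * For example:
+ * {{{
+ *   DATE '2019-01-01'
+ *   TIMESTAMP '2019-01-01 12:00:00'
+ *   X'1A2B'
+ * }}}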
+ */ + override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { + val value = string(ctx.STRING) + val valueType = ctx.identifier.getText.toUpperCase(Locale.ROOT) + try { + valueType match { + case "DATE" => + Literal(Date.valueOf(value)) + case "TIMESTAMP" => + Literal(Timestamp.valueOf(value)) + case "X" => + val padding = if (value.length % 2 != 0) "0" else "" + Literal(DatatypeConverter.parseHexBinary(padding + value)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) + } + } catch { + case e: IllegalArgumentException => + val message = Option(e.getMessage).getOrElse(s"Exception parsing $valueType") + throw new ParseException(message, ctx) + } + } + + /** + * Create a NULL literal expression. + */ + override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) { + Literal(null) + } + + /** + * Create a Boolean literal expression. + */ + override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) { + if (ctx.getText.toBoolean) { + Literal.TrueLiteral + } else { + Literal.FalseLiteral + } + } + + /** + * Create an integral literal expression. The code selects the most narrow integral type + * possible, either a BigDecimal, a Long or an Integer is returned. + */ + override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) { + BigDecimal(ctx.getText) match { + case v if v.isValidInt => + Literal(v.intValue()) + case v if v.isValidLong => + Literal(v.longValue()) + case v => Literal(v.underlying()) + } + } + + /** + * Create a decimal literal for a regular decimal number. + */ + override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(BigDecimal(ctx.getText).underlying()) + } + + /** Create a numeric literal expression. */ + private def numericLiteral + (ctx: NumberContext, minValue: BigDecimal, maxValue: BigDecimal, typeName: String) + (converter: String => Any): Literal = withOrigin(ctx) { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + try { + val rawBigDecimal = BigDecimal(rawStrippedQualifier) + if (rawBigDecimal < minValue || rawBigDecimal > maxValue) { + throw new ParseException(s"Numeric literal ${rawStrippedQualifier} does not " + + s"fit in range [${minValue}, ${maxValue}] for type ${typeName}", ctx) + } + Literal(converter(rawStrippedQualifier)) + } catch { + case e: NumberFormatException => + throw new ParseException(e.getMessage, ctx) + } + } + + /** + * Create a Byte Literal expression. + */ + override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = { + numericLiteral(ctx, Byte.MinValue, Byte.MaxValue, ByteType.simpleString)(_.toByte) + } + + /** + * Create a Short Literal expression. + */ + override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = { + numericLiteral(ctx, Short.MinValue, Short.MaxValue, ShortType.simpleString)(_.toShort) + } + + /** + * Create a Long Literal expression. + */ + override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = { + numericLiteral(ctx, Long.MinValue, Long.MaxValue, LongType.simpleString)(_.toLong) + } + + /** + * Create a Double Literal expression. + */ + override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = { + numericLiteral(ctx, Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble) + } + + /** + * Create a BigDecimal Literal expression. 
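+ * The raw token carries a two-character suffix (e.g. the `BD` in `3.14BD`), which is
+ * stripped before the value is parsed.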
+ */ + override def visitBigDecimalLiteral(ctx: BigDecimalLiteralContext): Literal = { + val raw = ctx.getText.substring(0, ctx.getText.length - 2) + try { + Literal(BigDecimal(raw).underlying()) + } catch { + case e: AnalysisException => + throw new ParseException(e.message, ctx) + } + } + + /** + * Create a String literal expression. + */ + override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { + Literal(createString(ctx)) + } + + /** + * Create a String from a string literal context. This supports multiple consecutive string + * literals, these are concatenated, for example this expression "'hello' 'world'" will be + * converted into "helloworld". + * + * Special characters can be escaped by using Hive/C-style escaping. + */ + private def createString(ctx: StringLiteralContext): String = { + if (conf.escapedStringLiterals) { + ctx.STRING().asScala.map(stringWithoutUnescape).mkString + } else { + ctx.STRING().asScala.map(string).mkString + } + } + + /** + * Create a [[CalendarInterval]] literal expression. An interval expression can contain multiple + * unit value pairs, for instance: interval 2 months 2 days. + */ + override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) { + val intervals = ctx.intervalField.asScala.map(visitIntervalField) + validate(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx) + Literal(intervals.reduce(_.add(_))) + } + + /** + * Create a [[CalendarInterval]] for a unit value pair. Two unit configuration types are + * supported: + * - Single unit. + * - From-To unit (only 'YEAR TO MONTH' and 'DAY TO SECOND' are supported). + */ + override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) { + import ctx._ + val s = value.getText + try { + val unitText = unit.getText.toLowerCase(Locale.ROOT) + val interval = (unitText, Option(to).map(_.getText.toLowerCase(Locale.ROOT))) match { + case (u, None) if u.endsWith("s") => + // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... + CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s) + case (u, None) => + CalendarInterval.fromSingleUnitString(u, s) + case ("year", Some("month")) => + CalendarInterval.fromYearMonthString(s) + case ("day", Some("second")) => + CalendarInterval.fromDayTimeString(s) + case (from, Some(t)) => + throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx) + } + validate(interval != null, "No interval can be constructed", ctx) + interval + } catch { + // Handle Exceptions thrown by CalendarInterval + case e: IllegalArgumentException => + val pe = new ParseException(e.getMessage, ctx) + pe.setStackTrace(e.getStackTrace) + throw pe + } + } + + /* ******************************************************************************************** + * DataType parsing + * ******************************************************************************************** */ + /** + * Create a Spark DataType. + */ + private def visitSparkDataType(ctx: DataTypeContext): DataType = { + HiveStringType.replaceCharType(typedVisit(ctx)) + } + + /** + * Resolve/create a primitive type. 
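+ * For example `int`, `decimal(10, 2)` or `varchar(255)`.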
+ */ + override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { + val dataType = ctx.identifier.getText.toLowerCase(Locale.ROOT) + (dataType, ctx.INTEGER_VALUE().asScala.toList) match { + case ("boolean", Nil) => BooleanType + case ("tinyint" | "byte", Nil) => ByteType + case ("smallint" | "short", Nil) => ShortType + case ("int" | "integer", Nil) => IntegerType + case ("bigint" | "long", Nil) => LongType + case ("float", Nil) => FloatType + case ("double", Nil) => DoubleType + case ("date", Nil) => DateType + case ("timestamp", Nil) => TimestampType + case ("string", Nil) => StringType + case ("char", length :: Nil) => CharType(length.getText.toInt) + case ("varchar", length :: Nil) => VarcharType(length.getText.toInt) + case ("binary", Nil) => BinaryType + case ("decimal", Nil) => DecimalType.USER_DEFAULT + case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0) + case ("decimal", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case (dt, params) => + val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt + throw new ParseException(s"DataType $dtStr is not supported.", ctx) + } + } + + /** + * Create a complex DataType. Arrays, Maps and Structures are supported. + */ + override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { + ctx.complex.getType match { + case HoodieSqlBaseParser.ARRAY => + ArrayType(typedVisit(ctx.dataType(0))) + case HoodieSqlBaseParser.MAP => + MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) + case HoodieSqlBaseParser.STRUCT => + StructType(Option(ctx.complexColTypeList).toSeq.flatMap(visitComplexColTypeList)) + } + } + + /** + * Create top level table schema. + */ + protected def createSchema(ctx: ColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.colType().asScala.map(visitColType) + } + + /** + * Create a top level [[StructField]] from a column definition. + */ + override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + + val builder = new MetadataBuilder + // Add comment to metadata + if (STRING != null) { + builder.putString("comment", string(STRING)) + } + // Add Hive type string to metadata. + val rawDataType = typedVisit[DataType](ctx.dataType) + val cleanedDataType = HiveStringType.replaceCharType(rawDataType) + if (rawDataType != cleanedDataType) { + builder.putString(HIVE_TYPE_STRING, rawDataType.catalogString) + } + + StructField( + identifier.getText, + cleanedDataType, + nullable = true, + builder.build()) + } + + /** + * Create a [[StructType]] from a sequence of [[StructField]]s. + */ + protected def createStructType(ctx: ComplexColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitComplexColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitComplexColTypeList( + ctx: ComplexColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.complexColType().asScala.map(visitComplexColType) + } + + /** + * Create a [[StructField]] from a column definition. 
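+ * This handles a single field of a STRUCT type, e.g. the `a: INT` in
+ * `STRUCT<a: INT, b: STRING COMMENT 'note'>`, optionally carrying a comment.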
+ */ + override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + val structField = StructField(identifier.getText, typedVisit(dataType), nullable = true) + if (STRING == null) structField else structField.withComment(string(STRING)) } } diff --git a/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/unsafe/types/CalendarInterval.java b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/unsafe/types/CalendarInterval.java new file mode 100644 index 0000000000000..92a7e412ad329 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark2/src/main/scala/org/apache/spark/sql/hudi/parser/unsafe/types/CalendarInterval.java @@ -0,0 +1,360 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hudi.parser.unsafe.types; + +import java.io.Serializable; +import java.util.Locale; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * The internal representation of interval type. + */ +public final class CalendarInterval implements Serializable { + public static final long MICROS_PER_MILLI = 1000L; + public static final long MICROS_PER_SECOND = MICROS_PER_MILLI * 1000; + public static final long MICROS_PER_MINUTE = MICROS_PER_SECOND * 60; + public static final long MICROS_PER_HOUR = MICROS_PER_MINUTE * 60; + public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24; + public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7; + + /** + * A function to generate regex which matches interval string's unit part like "3 years". + * + * First, we can leave out some units in interval string, and we only care about the value of + * unit, so here we use non-capturing group to wrap the actual regex. + * At the beginning of the actual regex, we should match spaces before the unit part. + * Next is the number part, starts with an optional "-" to represent negative value. We use + * capturing group to wrap this part as we need the value later. + * Finally is the unit name, ends with an optional "s". 
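+ * For example, {@code unitRegex("year")} yields a fragment that matches " 3 years"
+ * (or " -3 year") and captures the numeric value.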
+ */ + private static String unitRegex(String unit) { + return "(?:\\s+(-?\\d+)\\s+" + unit + "s?)?"; + } + + private static Pattern p = Pattern.compile("interval" + unitRegex("year") + unitRegex("month") + + unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") + + unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond")); + + private static Pattern yearMonthPattern = + Pattern.compile("^(?:['|\"])?([+|-])?(\\d+)-(\\d+)(?:['|\"])?$"); + + private static Pattern dayTimePattern = + Pattern.compile("^(?:['|\"])?([+|-])?(\\d+) (\\d+):(\\d+):(\\d+)(\\.(\\d+))?(?:['|\"])?$"); + + private static Pattern quoteTrimPattern = Pattern.compile("^(?:['|\"])?(.*?)(?:['|\"])?$"); + + private static long toLong(String s) { + if (s == null) { + return 0; + } else { + return Long.parseLong(s); + } + } + + /** + * Convert a string to CalendarInterval. Return null if the input string is not a valid interval. + * This method is case-sensitive and all characters in the input string should be in lower case. + */ + public static CalendarInterval fromString(String s) { + if (s == null) { + return null; + } + s = s.trim(); + Matcher m = p.matcher(s); + if (!m.matches() || s.equals("interval")) { + return null; + } else { + long months = toLong(m.group(1)) * 12 + toLong(m.group(2)); + long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK; + microseconds += toLong(m.group(4)) * MICROS_PER_DAY; + microseconds += toLong(m.group(5)) * MICROS_PER_HOUR; + microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE; + microseconds += toLong(m.group(7)) * MICROS_PER_SECOND; + microseconds += toLong(m.group(8)) * MICROS_PER_MILLI; + microseconds += toLong(m.group(9)); + return new CalendarInterval((int) months, microseconds); + } + } + + /** + * Convert a string to CalendarInterval. Unlike fromString, this method is case-insensitive and + * will throw IllegalArgumentException when the input string is not a valid interval. + * + * @throws IllegalArgumentException if the string is not a valid internal. + */ + public static CalendarInterval fromCaseInsensitiveString(String s) { + if (s == null || s.trim().isEmpty()) { + throw new IllegalArgumentException("Interval cannot be null or blank."); + } + String sInLowerCase = s.trim().toLowerCase(Locale.ROOT); + String interval = + sInLowerCase.startsWith("interval ") ? 
sInLowerCase : "interval " + sInLowerCase; + CalendarInterval cal = fromString(interval); + if (cal == null) { + throw new IllegalArgumentException("Invalid interval: " + s); + } + return cal; + } + + public static long toLongWithRange(String fieldName, + String s, long minValue, long maxValue) throws IllegalArgumentException { + long result = 0; + if (s != null) { + result = Long.parseLong(s); + if (result < minValue || result > maxValue) { + throw new IllegalArgumentException(String.format("%s %d outside range [%d, %d]", + fieldName, result, minValue, maxValue)); + } + } + return result; + } + + /** + * Parse YearMonth string in form: [-]YYYY-MM + * + * adapted from HiveIntervalYearMonth.valueOf + */ + public static CalendarInterval fromYearMonthString(String s) throws IllegalArgumentException { + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException("Interval year-month string was null"); + } + s = s.trim(); + Matcher m = yearMonthPattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match year-month format of 'y-m': " + s); + } else { + try { + int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1; + int years = (int) toLongWithRange("year", m.group(2), 0, Integer.MAX_VALUE); + int months = (int) toLongWithRange("month", m.group(3), 0, 11); + result = new CalendarInterval(sign * (years * 12 + months), 0); + } catch (Exception e) { + throw new IllegalArgumentException( + "Error parsing interval year-month string: " + e.getMessage(), e); + } + } + return result; + } + + /** + * Parse dayTime string in form: [-]d HH:mm:ss.nnnnnnnnn + * + * adapted from HiveIntervalDayTime.valueOf + */ + public static CalendarInterval fromDayTimeString(String s) throws IllegalArgumentException { + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException("Interval day-time string was null"); + } + s = s.trim(); + Matcher m = dayTimePattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match day-time format of 'd h:m:s.n': " + s); + } else { + try { + int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1; + long days = toLongWithRange("day", m.group(2), 0, Integer.MAX_VALUE); + long hours = toLongWithRange("hour", m.group(3), 0, 23); + long minutes = toLongWithRange("minute", m.group(4), 0, 59); + long seconds = toLongWithRange("second", m.group(5), 0, 59); + // Hive allow nanosecond precision interval + String nanoStr = m.group(7) == null ? 
null : (m.group(7) + "000000000").substring(0, 9); + long nanos = toLongWithRange("nanosecond", nanoStr, 0L, 999999999L); + result = new CalendarInterval(0, sign * ( + days * MICROS_PER_DAY + hours * MICROS_PER_HOUR + minutes * MICROS_PER_MINUTE + + seconds * MICROS_PER_SECOND + nanos / 1000L)); + } catch (Exception e) { + throw new IllegalArgumentException( + "Error parsing interval day-time string: " + e.getMessage(), e); + } + } + return result; + } + + public static CalendarInterval fromSingleUnitString(String unit, String s) + throws IllegalArgumentException { + + CalendarInterval result = null; + if (s == null) { + throw new IllegalArgumentException(String.format("Interval %s string was null", unit)); + } + s = s.trim(); + Matcher m = quoteTrimPattern.matcher(s); + if (!m.matches()) { + throw new IllegalArgumentException( + "Interval string does not match day-time format of 'd h:m:s.n': " + s); + } else { + try { + switch (unit) { + case "year": + int year = (int) toLongWithRange("year", m.group(1), + Integer.MIN_VALUE / 12, Integer.MAX_VALUE / 12); + result = new CalendarInterval(year * 12, 0L); + break; + case "month": + int month = (int) toLongWithRange("month", m.group(1), + Integer.MIN_VALUE, Integer.MAX_VALUE); + result = new CalendarInterval(month, 0L); + break; + case "week": + long week = toLongWithRange("week", m.group(1), + Long.MIN_VALUE / MICROS_PER_WEEK, Long.MAX_VALUE / MICROS_PER_WEEK); + result = new CalendarInterval(0, week * MICROS_PER_WEEK); + break; + case "day": + long day = toLongWithRange("day", m.group(1), + Long.MIN_VALUE / MICROS_PER_DAY, Long.MAX_VALUE / MICROS_PER_DAY); + result = new CalendarInterval(0, day * MICROS_PER_DAY); + break; + case "hour": + long hour = toLongWithRange("hour", m.group(1), + Long.MIN_VALUE / MICROS_PER_HOUR, Long.MAX_VALUE / MICROS_PER_HOUR); + result = new CalendarInterval(0, hour * MICROS_PER_HOUR); + break; + case "minute": + long minute = toLongWithRange("minute", m.group(1), + Long.MIN_VALUE / MICROS_PER_MINUTE, Long.MAX_VALUE / MICROS_PER_MINUTE); + result = new CalendarInterval(0, minute * MICROS_PER_MINUTE); + break; + case "second": { + long micros = parseSecondNano(m.group(1)); + result = new CalendarInterval(0, micros); + break; + } + case "millisecond": + long millisecond = toLongWithRange("millisecond", m.group(1), + Long.MIN_VALUE / MICROS_PER_MILLI, Long.MAX_VALUE / MICROS_PER_MILLI); + result = new CalendarInterval(0, millisecond * MICROS_PER_MILLI); + break; + case "microsecond": { + long micros = Long.parseLong(m.group(1)); + result = new CalendarInterval(0, micros); + break; + } + } + } catch (Exception e) { + throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e); + } + } + return result; + } + + /** + * Parse second_nano string in ss.nnnnnnnnn format to microseconds + */ + public static long parseSecondNano(String secondNano) throws IllegalArgumentException { + String[] parts = secondNano.split("\\."); + if (parts.length == 1) { + return toLongWithRange("second", parts[0], Long.MIN_VALUE / MICROS_PER_SECOND, + Long.MAX_VALUE / MICROS_PER_SECOND) * MICROS_PER_SECOND; + + } else if (parts.length == 2) { + long seconds = parts[0].equals("") ? 
0L : toLongWithRange("second", parts[0], + Long.MIN_VALUE / MICROS_PER_SECOND, Long.MAX_VALUE / MICROS_PER_SECOND); + long nanos = toLongWithRange("nanosecond", parts[1], 0L, 999999999L); + return seconds * MICROS_PER_SECOND + nanos / 1000L; + + } else { + throw new IllegalArgumentException( + "Interval string does not match second-nano format of ss.nnnnnnnnn"); + } + } + + public final int months; + public final long microseconds; + + public long milliseconds() { + return this.microseconds / MICROS_PER_MILLI; + } + + public CalendarInterval(int months, long microseconds) { + this.months = months; + this.microseconds = microseconds; + } + + public CalendarInterval add(CalendarInterval that) { + int months = this.months + that.months; + long microseconds = this.microseconds + that.microseconds; + return new CalendarInterval(months, microseconds); + } + + public CalendarInterval subtract(CalendarInterval that) { + int months = this.months - that.months; + long microseconds = this.microseconds - that.microseconds; + return new CalendarInterval(months, microseconds); + } + + public CalendarInterval negate() { + return new CalendarInterval(-this.months, -this.microseconds); + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || !(other instanceof CalendarInterval)) return false; + + CalendarInterval o = (CalendarInterval) other; + return this.months == o.months && this.microseconds == o.microseconds; + } + + @Override + public int hashCode() { + return 31 * months + (int) microseconds; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder("interval"); + + if (months != 0) { + appendUnit(sb, months / 12, "year"); + appendUnit(sb, months % 12, "month"); + } + + if (microseconds != 0) { + long rest = microseconds; + appendUnit(sb, rest / MICROS_PER_WEEK, "week"); + rest %= MICROS_PER_WEEK; + appendUnit(sb, rest / MICROS_PER_DAY, "day"); + rest %= MICROS_PER_DAY; + appendUnit(sb, rest / MICROS_PER_HOUR, "hour"); + rest %= MICROS_PER_HOUR; + appendUnit(sb, rest / MICROS_PER_MINUTE, "minute"); + rest %= MICROS_PER_MINUTE; + appendUnit(sb, rest / MICROS_PER_SECOND, "second"); + rest %= MICROS_PER_SECOND; + appendUnit(sb, rest / MICROS_PER_MILLI, "millisecond"); + rest %= MICROS_PER_MILLI; + appendUnit(sb, rest, "microsecond"); + } else if (months == 0) { + sb.append(" 0 microseconds"); + } + + return sb.toString(); + } + + private void appendUnit(StringBuilder sb, long value, String unit) { + if (value != 0) { + sb.append(' ').append(value).append(' ').append(unit).append('s'); + } + } +} diff --git a/hudi-spark-datasource/hudi-spark3-common/pom.xml b/hudi-spark-datasource/hudi-spark3-common/pom.xml index affa987372963..20faa6744178f 100644 --- a/hudi-spark-datasource/hudi-spark3-common/pom.xml +++ b/hudi-spark-datasource/hudi-spark3-common/pom.xml @@ -154,6 +154,42 @@ org.jacoco jacoco-maven-plugin + + org.antlr + antlr4-maven-plugin + ${antlr.version} + + + + antlr4 + + + + + true + true + ../hudi-spark3-common/src/main/antlr4/ + ../hudi-spark3-common/src/main/antlr4/imports + + + + org.antlr + antlr4-maven-plugin + ${antlr.version} + + + + antlr4 + + + + + true + true + ../hudi-spark3-common/src/main/antlr4/ + ../hudi-spark3-common/src/main/antlr4/imports + + @@ -244,4 +280,4 @@ - \ No newline at end of file + diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/Spark3Adapter.scala 
b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/Spark3Adapter.scala index a1c41e80ab362..ca1ad1afdf58a 100644 --- a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/Spark3Adapter.scala +++ b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/adapter/Spark3Adapter.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.adapter import org.apache.hudi.Spark3RowSerDe import org.apache.hudi.client.utils.SparkRowSerDe import org.apache.hudi.spark3.internal.ReflectUtil -import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Expression, Like} @@ -31,11 +30,10 @@ import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.Table import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.execution.datasources.{FilePartition, LogicalRelation, PartitionedFile, Spark3ParsePartitionUtil, SparkParsePartitionUtil} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hudi.SparkAdapter import org.apache.spark.sql.internal.SQLConf - -import scala.collection.JavaConverters.mapAsScalaMapConverter +import org.apache.spark.sql.{Row, SparkSession} /** * The adapter for spark3. diff --git a/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeTravelRelation.scala b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeTravelRelation.scala new file mode 100644 index 0000000000000..67f942881c39d --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3-common/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TimeTravelRelation.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} + +case class TimeTravelRelation( + table: LogicalPlan, + timestamp: Option[Expression], + version: Option[String]) extends Command { + override def children: Seq[LogicalPlan] = Seq.empty + + override def output: Seq[Attribute] = Nil + + override lazy val resolved: Boolean = false + + def withNewChildrenInternal(newChildren: IndexedSeq[LogicalPlan]): LogicalPlan = this +} diff --git a/hudi-spark-datasource/hudi-spark3.1.x/pom.xml b/hudi-spark-datasource/hudi-spark3.1.x/pom.xml index f6d9f7d557216..374b1722a365d 100644 --- a/hudi-spark-datasource/hudi-spark3.1.x/pom.xml +++ b/hudi-spark-datasource/hudi-spark3.1.x/pom.xml @@ -144,6 +144,24 @@ org.jacoco jacoco-maven-plugin + + org.antlr + antlr4-maven-plugin + ${antlr.version} + + + + antlr4 + + + + + true + true + ../hudi-spark3.1.x/src/main/antlr4 + ../hudi-spark3.1.x/src/main/antlr4/imports + + @@ -157,7 +175,7 @@ org.apache.spark spark-sql_2.12 - ${spark3.version} + ${spark3.1.version} true diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/antlr4/imports/SqlBase.g4 b/hudi-spark-datasource/hudi-spark3.1.x/src/main/antlr4/imports/SqlBase.g4 new file mode 100644 index 0000000000000..500f66c9e53f9 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/antlr4/imports/SqlBase.g4 @@ -0,0 +1,1852 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. + */ + +// The parser file is forked from spark 3.1.2's SqlBase.g4. +grammar SqlBase; + +@parser::members { + /** + * When false, INTERSECT is given the greater precedence over the other set + * operations (UNION, EXCEPT and MINUS) as per the SQL standard. + */ + public boolean legacy_setops_precedence_enbled = false; + + /** + * When false, a literal with an exponent would be converted into + * double type rather than decimal type. + */ + public boolean legacy_exponent_literal_as_decimal_enabled = false; + + /** + * When true, the behavior of keywords follows ANSI SQL standard. + */ + public boolean SQL_standard_keyword_behavior = false; +} + +@lexer::members { + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. 
+ */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + + /** + * This method will be called when we see '/*' and try to match it as a bracketed comment. + * If the next character is '+', it should be parsed as hint later, and we cannot match + * it as a bracketed comment. + * + * Returns true if the next character is '+'. + */ + public boolean isHint() { + int nextChar = _input.LA(1); + if (nextChar == '+') { + return true; + } else { + return false; + } + } +} + +singleStatement + : statement ';'* EOF + ; + +singleExpression + : namedExpression EOF + ; + +singleTableIdentifier + : tableIdentifier EOF + ; + +singleMultipartIdentifier + : multipartIdentifier EOF + ; + +singleFunctionIdentifier + : functionIdentifier EOF + ; + +singleDataType + : dataType EOF + ; + +singleTableSchema + : colTypeList EOF + ; + +statement + : query #statementDefault + | ctes? dmlStatementNoWith #dmlStatement + | USE NAMESPACE? multipartIdentifier #use + | CREATE namespace (IF NOT EXISTS)? multipartIdentifier + (commentSpec | + locationSpec | + (WITH (DBPROPERTIES | PROPERTIES) tablePropertyList))* #createNamespace + | ALTER namespace multipartIdentifier + SET (DBPROPERTIES | PROPERTIES) tablePropertyList #setNamespaceProperties + | ALTER namespace multipartIdentifier + SET locationSpec #setNamespaceLocation + | DROP namespace (IF EXISTS)? multipartIdentifier + (RESTRICT | CASCADE)? #dropNamespace + | SHOW (DATABASES | NAMESPACES) ((FROM | IN) multipartIdentifier)? + (LIKE? pattern=STRING)? #showNamespaces + | createTableHeader ('(' colTypeList ')')? tableProvider? + createTableClauses + (AS? query)? #createTable + | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier + LIKE source=tableIdentifier + (tableProvider | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* #createTableLike + | replaceTableHeader ('(' colTypeList ')')? tableProvider? + createTableClauses + (AS? query)? #replaceTable + | ANALYZE TABLE multipartIdentifier partitionSpec? COMPUTE STATISTICS + (identifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze + | ALTER TABLE multipartIdentifier + ADD (COLUMN | COLUMNS) + columns=qualifiedColTypeWithPositionList #addTableColumns + | ALTER TABLE multipartIdentifier + ADD (COLUMN | COLUMNS) + '(' columns=qualifiedColTypeWithPositionList ')' #addTableColumns + | ALTER TABLE table=multipartIdentifier + RENAME COLUMN + from=multipartIdentifier TO to=errorCapturingIdentifier #renameTableColumn + | ALTER TABLE multipartIdentifier + DROP (COLUMN | COLUMNS) + '(' columns=multipartIdentifierList ')' #dropTableColumns + | ALTER TABLE multipartIdentifier + DROP (COLUMN | COLUMNS) columns=multipartIdentifierList #dropTableColumns + | ALTER (TABLE | VIEW) from=multipartIdentifier + RENAME TO to=multipartIdentifier #renameTable + | ALTER (TABLE | VIEW) multipartIdentifier + SET TBLPROPERTIES tablePropertyList #setTableProperties + | ALTER (TABLE | VIEW) multipartIdentifier + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties + | ALTER TABLE table=multipartIdentifier + (ALTER | CHANGE) COLUMN? column=multipartIdentifier + alterColumnAction? #alterTableAlterColumn + | ALTER TABLE table=multipartIdentifier partitionSpec? + CHANGE COLUMN? + colName=multipartIdentifier colType colPosition? #hiveChangeColumn + | ALTER TABLE table=multipartIdentifier partitionSpec? 
+ REPLACE COLUMNS + '(' columns=qualifiedColTypeWithPositionList ')' #hiveReplaceColumns + | ALTER TABLE multipartIdentifier (partitionSpec)? + SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe + | ALTER TABLE multipartIdentifier (partitionSpec)? + SET SERDEPROPERTIES tablePropertyList #setTableSerDe + | ALTER (TABLE | VIEW) multipartIdentifier ADD (IF NOT EXISTS)? + partitionSpecLocation+ #addTablePartition + | ALTER TABLE multipartIdentifier + from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition + | ALTER (TABLE | VIEW) multipartIdentifier + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions + | ALTER TABLE multipartIdentifier + (partitionSpec)? SET locationSpec #setTableLocation + | ALTER TABLE multipartIdentifier RECOVER PARTITIONS #recoverPartitions + | DROP TABLE (IF EXISTS)? multipartIdentifier PURGE? #dropTable + | DROP VIEW (IF EXISTS)? multipartIdentifier #dropView + | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? + VIEW (IF NOT EXISTS)? multipartIdentifier + identifierCommentList? + (commentSpec | + (PARTITIONED ON identifierList) | + (TBLPROPERTIES tablePropertyList))* + AS query #createView + | CREATE (OR REPLACE)? GLOBAL? TEMPORARY VIEW + tableIdentifier ('(' colTypeList ')')? tableProvider + (OPTIONS tablePropertyList)? #createTempViewUsing + | ALTER VIEW multipartIdentifier AS? query #alterViewQuery + | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF NOT EXISTS)? + multipartIdentifier AS className=STRING + (USING resource (',' resource)*)? #createFunction + | DROP TEMPORARY? FUNCTION (IF EXISTS)? multipartIdentifier #dropFunction + | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? + statement #explain + | SHOW TABLES ((FROM | IN) multipartIdentifier)? + (LIKE? pattern=STRING)? #showTables + | SHOW TABLE EXTENDED ((FROM | IN) ns=multipartIdentifier)? + LIKE pattern=STRING partitionSpec? #showTable + | SHOW TBLPROPERTIES table=multipartIdentifier + ('(' key=tablePropertyKey ')')? #showTblProperties + | SHOW COLUMNS (FROM | IN) table=multipartIdentifier + ((FROM | IN) ns=multipartIdentifier)? #showColumns + | SHOW VIEWS ((FROM | IN) multipartIdentifier)? + (LIKE? pattern=STRING)? #showViews + | SHOW PARTITIONS multipartIdentifier partitionSpec? #showPartitions + | SHOW identifier? FUNCTIONS + (LIKE? (multipartIdentifier | pattern=STRING))? #showFunctions + | SHOW CREATE TABLE multipartIdentifier (AS SERDE)? #showCreateTable + | SHOW CURRENT NAMESPACE #showCurrentNamespace + | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction + | (DESC | DESCRIBE) namespace EXTENDED? + multipartIdentifier #describeNamespace + | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? + multipartIdentifier partitionSpec? describeColName? #describeRelation + | (DESC | DESCRIBE) QUERY? query #describeQuery + | COMMENT ON namespace multipartIdentifier IS + comment=(STRING | NULL) #commentNamespace + | COMMENT ON TABLE multipartIdentifier IS comment=(STRING | NULL) #commentTable + | REFRESH TABLE multipartIdentifier #refreshTable + | REFRESH FUNCTION multipartIdentifier #refreshFunction + | REFRESH (STRING | .*?) #refreshResource + | CACHE LAZY? TABLE multipartIdentifier + (OPTIONS options=tablePropertyList)? (AS? query)? #cacheTable + | UNCACHE TABLE (IF EXISTS)? multipartIdentifier #uncacheTable + | CLEAR CACHE #clearCache + | LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE + multipartIdentifier partitionSpec? #loadData + | TRUNCATE TABLE multipartIdentifier partitionSpec? 
#truncateTable + | MSCK REPAIR TABLE multipartIdentifier #repairTable + | op=(ADD | LIST) identifier (STRING | .*?) #manageResource + | SET ROLE .*? #failNativeCommand + | SET TIME ZONE interval #setTimeZone + | SET TIME ZONE timezone=(STRING | LOCAL) #setTimeZone + | SET TIME ZONE .*? #setTimeZone + | SET configKey EQ configValue #setQuotedConfiguration + | SET configKey (EQ .*?)? #setQuotedConfiguration + | SET .*? EQ configValue #setQuotedConfiguration + | SET .*? #setConfiguration + | RESET configKey #resetQuotedConfiguration + | RESET .*? #resetConfiguration + | unsupportedHiveNativeCommands .*? #failNativeCommand + ; + +configKey + : quotedIdentifier + ; + +configValue + : quotedIdentifier + ; + +unsupportedHiveNativeCommands + : kw1=CREATE kw2=ROLE + | kw1=DROP kw2=ROLE + | kw1=GRANT kw2=ROLE? + | kw1=REVOKE kw2=ROLE? + | kw1=SHOW kw2=GRANT + | kw1=SHOW kw2=ROLE kw3=GRANT? + | kw1=SHOW kw2=PRINCIPALS + | kw1=SHOW kw2=ROLES + | kw1=SHOW kw2=CURRENT kw3=ROLES + | kw1=EXPORT kw2=TABLE + | kw1=IMPORT kw2=TABLE + | kw1=SHOW kw2=COMPACTIONS + | kw1=SHOW kw2=CREATE kw3=TABLE + | kw1=SHOW kw2=TRANSACTIONS + | kw1=SHOW kw2=INDEXES + | kw1=SHOW kw2=LOCKS + | kw1=CREATE kw2=INDEX + | kw1=DROP kw2=INDEX + | kw1=ALTER kw2=INDEX + | kw1=LOCK kw2=TABLE + | kw1=LOCK kw2=DATABASE + | kw1=UNLOCK kw2=TABLE + | kw1=UNLOCK kw2=DATABASE + | kw1=CREATE kw2=TEMPORARY kw3=MACRO + | kw1=DROP kw2=TEMPORARY kw3=MACRO + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=CLUSTERED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=CLUSTERED kw4=BY + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SORTED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=SKEWED kw4=BY + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SKEWED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=STORED kw5=AS kw6=DIRECTORIES + | kw1=ALTER kw2=TABLE tableIdentifier kw3=SET kw4=SKEWED kw5=LOCATION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=EXCHANGE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=ARCHIVE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=UNARCHIVE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=TOUCH + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=COMPACT + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CONCATENATE + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=SET kw4=FILEFORMAT + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=REPLACE kw4=COLUMNS + | kw1=START kw2=TRANSACTION + | kw1=COMMIT + | kw1=ROLLBACK + | kw1=DFS + ; + +createTableHeader + : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? multipartIdentifier + ; + +replaceTableHeader + : (CREATE OR)? REPLACE TABLE multipartIdentifier + ; + +bucketSpec + : CLUSTERED BY identifierList + (SORTED BY orderedIdentifierList)? + INTO INTEGER_VALUE BUCKETS + ; + +skewSpec + : SKEWED BY identifierList + ON (constantList | nestedConstantList) + (STORED AS DIRECTORIES)? + ; + +locationSpec + : LOCATION STRING + ; + +commentSpec + : COMMENT STRING + ; + +query + : ctes? queryTerm queryOrganization + ; + +insertInto + : INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? identifierList? #insertOverwriteTable + | INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? identifierList? #insertIntoTable + | INSERT OVERWRITE LOCAL? DIRECTORY path=STRING rowFormat? createFileFormat? #insertOverwriteHiveDir + | INSERT OVERWRITE LOCAL? DIRECTORY (path=STRING)? tableProvider (OPTIONS options=tablePropertyList)? 
#insertOverwriteDir + ; + +partitionSpecLocation + : partitionSpec locationSpec? + ; + +partitionSpec + : PARTITION '(' partitionVal (',' partitionVal)* ')' + ; + +partitionVal + : identifier (EQ constant)? + ; + +namespace + : NAMESPACE + | DATABASE + | SCHEMA + ; + +describeFuncName + : qualifiedName + | STRING + | comparisonOperator + | arithmeticOperator + | predicateOperator + ; + +describeColName + : nameParts+=identifier ('.' nameParts+=identifier)* + ; + +ctes + : WITH namedQuery (',' namedQuery)* + ; + +namedQuery + : name=errorCapturingIdentifier (columnAliases=identifierList)? AS? '(' query ')' + ; + +tableProvider + : USING multipartIdentifier + ; + +createTableClauses + :((OPTIONS options=tablePropertyList) | + (PARTITIONED BY partitioning=partitionFieldList) | + skewSpec | + bucketSpec | + rowFormat | + createFileFormat | + locationSpec | + commentSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* + ; + +tablePropertyList + : '(' tableProperty (',' tableProperty)* ')' + ; + +tableProperty + : key=tablePropertyKey (EQ? value=tablePropertyValue)? + ; + +tablePropertyKey + : identifier ('.' identifier)* + | STRING + ; + +tablePropertyValue + : INTEGER_VALUE + | DECIMAL_VALUE + | booleanValue + | STRING + ; + +constantList + : '(' constant (',' constant)* ')' + ; + +nestedConstantList + : '(' constantList (',' constantList)* ')' + ; + +createFileFormat + : STORED AS fileFormat + | STORED BY storageHandler + ; + +fileFormat + : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING #tableFileFormat + | identifier #genericFileFormat + ; + +storageHandler + : STRING (WITH SERDEPROPERTIES tablePropertyList)? + ; + +resource + : identifier STRING + ; + +dmlStatementNoWith + : insertInto queryTerm queryOrganization #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery + | DELETE FROM multipartIdentifier tableAlias whereClause? #deleteFromTable + | UPDATE multipartIdentifier tableAlias setClause whereClause? #updateTable + | MERGE INTO target=multipartIdentifier targetAlias=tableAlias + USING (source=multipartIdentifier | + '(' sourceQuery=query')') sourceAlias=tableAlias + ON mergeCondition=booleanExpression + matchedClause* + notMatchedClause* #mergeIntoTable + ; + +queryOrganization + : (ORDER BY order+=sortItem (',' order+=sortItem)*)? + (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)? + (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)? + (SORT BY sort+=sortItem (',' sort+=sortItem)*)? + windowClause? + (LIMIT (ALL | limit=expression))? + ; + +multiInsertQueryBody + : insertInto fromStatementBody + ; + +queryTerm + : queryPrimary #queryTermDefault + | left=queryTerm {legacy_setops_precedence_enbled}? + operator=(INTERSECT | UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + | left=queryTerm {!legacy_setops_precedence_enbled}? + operator=INTERSECT setQuantifier? right=queryTerm #setOperation + | left=queryTerm {!legacy_setops_precedence_enbled}? + operator=(UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + ; + +queryPrimary + : querySpecification #queryPrimaryDefault + | fromStatement #fromStmt + | TABLE multipartIdentifier #table + | inlineTable #inlineTableDefault1 + | '(' query ')' #subquery + ; + +sortItem + : expression ordering=(ASC | DESC)? (NULLS nullOrder=(LAST | FIRST))? + ; + +fromStatement + : fromClause fromStatementBody+ + ; + +fromStatementBody + : transformClause + whereClause? + queryOrganization + | selectClause + lateralView* + whereClause? 
+ aggregationClause? + havingClause? + windowClause? + queryOrganization + ; + +querySpecification + : transformClause + fromClause? + whereClause? #transformQuerySpecification + | selectClause + fromClause? + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? #regularQuerySpecification + ; + +transformClause + : (SELECT kind=TRANSFORM '(' namedExpressionSeq ')' + | kind=MAP namedExpressionSeq + | kind=REDUCE namedExpressionSeq) + inRowFormat=rowFormat? + (RECORDWRITER recordWriter=STRING)? + USING script=STRING + (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))? + outRowFormat=rowFormat? + (RECORDREADER recordReader=STRING)? + ; + +selectClause + : SELECT (hints+=hint)* setQuantifier? namedExpressionSeq + ; + +setClause + : SET assignmentList + ; + +matchedClause + : WHEN MATCHED (AND matchedCond=booleanExpression)? THEN matchedAction + ; +notMatchedClause + : WHEN NOT MATCHED (AND notMatchedCond=booleanExpression)? THEN notMatchedAction + ; + +matchedAction + : DELETE + | UPDATE SET ASTERISK + | UPDATE SET assignmentList + ; + +notMatchedAction + : INSERT ASTERISK + | INSERT '(' columns=multipartIdentifierList ')' + VALUES '(' expression (',' expression)* ')' + ; + +assignmentList + : assignment (',' assignment)* + ; + +assignment + : key=multipartIdentifier EQ value=expression + ; + +whereClause + : WHERE booleanExpression + ; + +havingClause + : HAVING booleanExpression + ; + +hint + : '/*+' hintStatements+=hintStatement (','? hintStatements+=hintStatement)* '*/' + ; + +hintStatement + : hintName=identifier + | hintName=identifier '(' parameters+=primaryExpression (',' parameters+=primaryExpression)* ')' + ; + +fromClause + : FROM relation (',' relation)* lateralView* pivotClause? + ; + +temporalClause + : FOR? (SYSTEM_TIME | TIMESTAMP) AS OF timestamp=valueExpression + | FOR? (SYSTEM_VERSION | VERSION) AS OF version=(INTEGER_VALUE | STRING) + ; + +aggregationClause + : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( + WITH kind=ROLLUP + | WITH kind=CUBE + | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? + | GROUP BY kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')' + ; + +groupingSet + : '(' (expression (',' expression)*)? ')' + | expression + ; + +pivotClause + : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn IN '(' pivotValues+=pivotValue (',' pivotValues+=pivotValue)* ')' ')' + ; + +pivotColumn + : identifiers+=identifier + | '(' identifiers+=identifier (',' identifiers+=identifier)* ')' + ; + +pivotValue + : expression (AS? identifier)? + ; + +lateralView + : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? + ; + +setQuantifier + : DISTINCT + | ALL + ; + +relation + : relationPrimary joinRelation* + ; + +joinRelation + : (joinType) JOIN right=relationPrimary joinCriteria? + | NATURAL joinType JOIN right=relationPrimary + ; + +joinType + : INNER? + | CROSS + | LEFT OUTER? + | LEFT? SEMI + | RIGHT OUTER? + | FULL OUTER? + | LEFT? ANTI + ; + +joinCriteria + : ON booleanExpression + | USING identifierList + ; + +sample + : TABLESAMPLE '(' sampleMethod? ')' + ; + +sampleMethod + : negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile + | expression ROWS #sampleByRows + | sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE + (ON (identifier | qualifiedName '(' ')'))? 
#sampleByBucket + | bytes=expression #sampleByBytes + ; + +identifierList + : '(' identifierSeq ')' + ; + +identifierSeq + : ident+=errorCapturingIdentifier (',' ident+=errorCapturingIdentifier)* + ; + +orderedIdentifierList + : '(' orderedIdentifier (',' orderedIdentifier)* ')' + ; + +orderedIdentifier + : ident=errorCapturingIdentifier ordering=(ASC | DESC)? + ; + +identifierCommentList + : '(' identifierComment (',' identifierComment)* ')' + ; + +identifierComment + : identifier commentSpec? + ; + +relationPrimary + : multipartIdentifier temporalClause? + sample? tableAlias #tableName + | '(' query ')' sample? tableAlias #aliasedQuery + | '(' relation ')' sample? tableAlias #aliasedRelation + | inlineTable #inlineTableDefault2 + | functionTable #tableValuedFunction + ; + +inlineTable + : VALUES expression (',' expression)* tableAlias + ; + +functionTable + : funcName=errorCapturingIdentifier '(' (expression (',' expression)*)? ')' tableAlias + ; + +tableAlias + : (AS? strictIdentifier identifierList?)? + ; + +rowFormat + : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde + | ROW FORMAT DELIMITED + (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)? + (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)? + (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)? + (LINES TERMINATED BY linesSeparatedBy=STRING)? + (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited + ; + +multipartIdentifierList + : multipartIdentifier (',' multipartIdentifier)* + ; + +multipartIdentifier + : parts+=errorCapturingIdentifier ('.' parts+=errorCapturingIdentifier)* + ; + +tableIdentifier + : (db=errorCapturingIdentifier '.')? table=errorCapturingIdentifier + ; + +functionIdentifier + : (db=errorCapturingIdentifier '.')? function=errorCapturingIdentifier + ; + +namedExpression + : expression (AS? (name=errorCapturingIdentifier | identifierList))? + ; + +namedExpressionSeq + : namedExpression (',' namedExpression)* + ; + +partitionFieldList + : '(' fields+=partitionField (',' fields+=partitionField)* ')' + ; + +partitionField + : transform #partitionTransform + | colType #partitionColumn + ; + +transform + : qualifiedName #identityTransform + | transformName=identifier + '(' argument+=transformArgument (',' argument+=transformArgument)* ')' #applyTransform + ; + +transformArgument + : qualifiedName + | constant + ; + +expression + : booleanExpression + ; + +booleanExpression + : NOT booleanExpression #logicalNot + | EXISTS '(' query ')' #exists + | valueExpression predicate? #predicated + | left=booleanExpression operator=AND right=booleanExpression #logicalBinary + | left=booleanExpression operator=OR right=booleanExpression #logicalBinary + ; + +predicate + : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression + | NOT? kind=IN '(' expression (',' expression)* ')' + | NOT? kind=IN '(' query ')' + | NOT? kind=RLIKE pattern=valueExpression + | NOT? kind=LIKE quantifier=(ANY | SOME | ALL) ('('')' | '(' expression (',' expression)* ')') + | NOT? kind=LIKE pattern=valueExpression (ESCAPE escapeChar=STRING)? + | IS NOT? kind=NULL + | IS NOT? kind=(TRUE | FALSE | UNKNOWN) + | IS NOT? 
kind=DISTINCT FROM right=valueExpression + ; + +valueExpression + : primaryExpression #valueExpressionDefault + | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary + | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(PLUS | MINUS | CONCAT_PIPE) right=valueExpression #arithmeticBinary + | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary + | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary + | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary + | left=valueExpression comparisonOperator right=valueExpression #comparison + ; + +primaryExpression + : name=(CURRENT_DATE | CURRENT_TIMESTAMP) #currentDatetime + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | CAST '(' expression AS dataType ')' #cast + | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? ')' #struct + | FIRST '(' expression (IGNORE NULLS)? ')' #first + | LAST '(' expression (IGNORE NULLS)? ')' #last + | POSITION '(' substr=valueExpression IN str=valueExpression ')' #position + | constant #constantDefault + | ASTERISK #star + | qualifiedName '.' ASTERISK #star + | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor + | '(' query ')' #subqueryExpression + | functionName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' + (FILTER '(' WHERE where=booleanExpression ')')? (OVER windowSpec)? #functionCall + | identifier '->' expression #lambda + | '(' identifier (',' identifier)+ ')' '->' expression #lambda + | value=primaryExpression '[' index=valueExpression ']' #subscript + | identifier #columnReference + | base=primaryExpression '.' fieldName=identifier #dereference + | '(' expression ')' #parenthesizedExpression + | EXTRACT '(' field=identifier FROM source=valueExpression ')' #extract + | (SUBSTR | SUBSTRING) '(' str=valueExpression (FROM | ',') pos=valueExpression + ((FOR | ',') len=valueExpression)? ')' #substring + | TRIM '(' trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)? + FROM srcStr=valueExpression ')' #trim + | OVERLAY '(' input=valueExpression PLACING replace=valueExpression + FROM position=valueExpression (FOR length=valueExpression)? ')' #overlay + ; + +constant + : NULL #nullLiteral + | interval #intervalLiteral + | identifier STRING #typeConstructor + | number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + ; + +comparisonOperator + : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ + ; + +arithmeticOperator + : PLUS | MINUS | ASTERISK | SLASH | PERCENT | DIV | TILDE | AMPERSAND | PIPE | CONCAT_PIPE | HAT + ; + +predicateOperator + : OR | AND | IN | NOT + ; + +booleanValue + : TRUE | FALSE + ; + +interval + : INTERVAL (errorCapturingMultiUnitsInterval | errorCapturingUnitToUnitInterval)? + ; + +errorCapturingMultiUnitsInterval + : multiUnitsInterval unitToUnitInterval? + ; + +multiUnitsInterval + : (intervalValue unit+=identifier)+ + ; + +errorCapturingUnitToUnitInterval + : body=unitToUnitInterval (error1=multiUnitsInterval | error2=unitToUnitInterval)? + ; + +unitToUnitInterval + : value=intervalValue from=identifier TO to=identifier + ; + +intervalValue + : (PLUS | MINUS)? 
(INTEGER_VALUE | DECIMAL_VALUE) + | STRING + ; + +colPosition + : position=FIRST | position=AFTER afterCol=errorCapturingIdentifier + ; + +dataType + : complex=ARRAY '<' dataType '>' #complexDataType + | complex=MAP '<' dataType ',' dataType '>' #complexDataType + | complex=STRUCT ('<' complexColTypeList? '>' | NEQ) #complexDataType + | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType + ; + +qualifiedColTypeWithPositionList + : qualifiedColTypeWithPosition (',' qualifiedColTypeWithPosition)* + ; + +qualifiedColTypeWithPosition + : name=multipartIdentifier dataType (NOT NULL)? commentSpec? colPosition? + ; + +colTypeList + : colType (',' colType)* + ; + +colType + : colName=errorCapturingIdentifier dataType (NOT NULL)? commentSpec? + ; + +complexColTypeList + : complexColType (',' complexColType)* + ; + +complexColType + : identifier ':' dataType (NOT NULL)? commentSpec? + ; + +whenClause + : WHEN condition=expression THEN result=expression + ; + +windowClause + : WINDOW namedWindow (',' namedWindow)* + ; + +namedWindow + : name=errorCapturingIdentifier AS windowSpec + ; + +windowSpec + : name=errorCapturingIdentifier #windowRef + | '('name=errorCapturingIdentifier')' #windowRef + | '(' + ( CLUSTER BY partition+=expression (',' partition+=expression)* + | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? + ((ORDER | SORT) BY sortItem (',' sortItem)*)?) + windowFrame? + ')' #windowDef + ; + +windowFrame + : frameType=RANGE start=frameBound + | frameType=ROWS start=frameBound + | frameType=RANGE BETWEEN start=frameBound AND end=frameBound + | frameType=ROWS BETWEEN start=frameBound AND end=frameBound + ; + +frameBound + : UNBOUNDED boundType=(PRECEDING | FOLLOWING) + | boundType=CURRENT ROW + | expression boundType=(PRECEDING | FOLLOWING) + ; + +qualifiedNameList + : qualifiedName (',' qualifiedName)* + ; + +functionName + : qualifiedName + | FILTER + | LEFT + | RIGHT + ; + +qualifiedName + : identifier ('.' identifier)* + ; + +// this rule is used for explicitly capturing wrong identifiers such as test-table, which should actually be `test-table` +// replace identifier with errorCapturingIdentifier where the immediate follow symbol is not an expression, otherwise +// valid expressions such as "a-b" can be recognized as an identifier +errorCapturingIdentifier + : identifier errorCapturingIdentifierExtra + ; + +// extra left-factoring grammar +errorCapturingIdentifierExtra + : (MINUS identifier)+ #errorIdent + | #realIdent + ; + +identifier + : strictIdentifier + | {!SQL_standard_keyword_behavior}? strictNonReserved + ; + +strictIdentifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | {SQL_standard_keyword_behavior}? ansiNonReserved #unquotedIdentifier + | {!SQL_standard_keyword_behavior}? nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +number + : {!legacy_exponent_literal_as_decimal_enabled}? MINUS? EXPONENT_VALUE #exponentLiteral + | {!legacy_exponent_literal_as_decimal_enabled}? MINUS? DECIMAL_VALUE #decimalLiteral + | {legacy_exponent_literal_as_decimal_enabled}? MINUS? (EXPONENT_VALUE | DECIMAL_VALUE) #legacyDecimalLiteral + | MINUS? INTEGER_VALUE #integerLiteral + | MINUS? BIGINT_LITERAL #bigIntLiteral + | MINUS? SMALLINT_LITERAL #smallIntLiteral + | MINUS? TINYINT_LITERAL #tinyIntLiteral + | MINUS? DOUBLE_LITERAL #doubleLiteral + | MINUS? FLOAT_LITERAL #floatLiteral + | MINUS? 
BIGDECIMAL_LITERAL #bigDecimalLiteral + ; + +alterColumnAction + : TYPE dataType + | commentSpec + | colPosition + | setOrDrop=(SET | DROP) NOT NULL + ; + +// When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. +// - Reserved keywords: +// Keywords that are reserved and can't be used as identifiers for table, view, column, +// function, alias, etc. +// - Non-reserved keywords: +// Keywords that have a special meaning only in particular contexts and can be used as +// identifiers in other contexts. For example, `EXPLAIN SELECT ...` is a command, but EXPLAIN +// can be used as identifiers in other places. +// You can find the full keywords list by searching "Start of the keywords list" in this file. +// The non-reserved keywords are listed below. Keywords not in this list are reserved keywords. +ansiNonReserved +//--ANSI-NON-RESERVED-START + : ADD + | AFTER + | ALTER + | ANALYZE + | ANTI + | ARCHIVE + | ARRAY + | ASC + | AT + | BETWEEN + | BUCKET + | BUCKETS + | BY + | CACHE + | CASCADE + | CHANGE + | CLEAR + | CLUSTER + | CLUSTERED + | CODEGEN + | COLLECTION + | COLUMNS + | COMMENT + | COMMIT + | COMPACT + | COMPACTIONS + | COMPUTE + | CONCATENATE + | COST + | CUBE + | CURRENT + | DATA + | DATABASE + | DATABASES + | DBPROPERTIES + | DEFINED + | DELETE + | DELIMITED + | DESC + | DESCRIBE + | DFS + | DIRECTORIES + | DIRECTORY + | DISTRIBUTE + | DIV + | DROP + | ESCAPED + | EXCHANGE + | EXISTS + | EXPLAIN + | EXPORT + | EXTENDED + | EXTERNAL + | EXTRACT + | FIELDS + | FILEFORMAT + | FIRST + | FOLLOWING + | FORMAT + | FORMATTED + | FUNCTION + | FUNCTIONS + | GLOBAL + | GROUPING + | IF + | IGNORE + | IMPORT + | INDEX + | INDEXES + | INPATH + | INPUTFORMAT + | INSERT + | INTERVAL + | ITEMS + | KEYS + | LAST + | LATERAL + | LAZY + | LIKE + | LIMIT + | LINES + | LIST + | LOAD + | LOCAL + | LOCATION + | LOCK + | LOCKS + | LOGICAL + | MACRO + | MAP + | MATCHED + | MERGE + | MSCK + | NAMESPACE + | NAMESPACES + | NO + | NULLS + | OF + | OPTION + | OPTIONS + | OUT + | OUTPUTFORMAT + | OVER + | OVERLAY + | OVERWRITE + | PARTITION + | PARTITIONED + | PARTITIONS + | PERCENTLIT + | PIVOT + | PLACING + | POSITION + | PRECEDING + | PRINCIPALS + | PROPERTIES + | PURGE + | QUERY + | RANGE + | RECORDREADER + | RECORDWRITER + | RECOVER + | REDUCE + | REFRESH + | RENAME + | REPAIR + | REPLACE + | RESET + | RESTRICT + | REVOKE + | RLIKE + | ROLE + | ROLES + | ROLLBACK + | ROLLUP + | ROW + | ROWS + | SCHEMA + | SEMI + | SEPARATED + | SERDE + | SERDEPROPERTIES + | SET + | SETMINUS + | SETS + | SHOW + | SKEWED + | SORT + | SORTED + | START + | STATISTICS + | STORED + | STRATIFY + | STRUCT + | SUBSTR + | SUBSTRING + | TABLES + | TABLESAMPLE + | TBLPROPERTIES + | TEMPORARY + | TERMINATED + | TOUCH + | TRANSACTION + | TRANSACTIONS + | TRANSFORM + | TRIM + | TRUE + | TRUNCATE + | TYPE + | UNARCHIVE + | UNBOUNDED + | UNCACHE + | UNLOCK + | UNSET + | UPDATE + | USE + | VALUES + | VIEW + | VIEWS + | WINDOW + | ZONE +//--ANSI-NON-RESERVED-END + ; + +// When `SQL_standard_keyword_behavior=false`, there are 2 kinds of keywords in Spark SQL. +// - Non-reserved keywords: +// Same definition as the one when `SQL_standard_keyword_behavior=true`. +// - Strict-non-reserved keywords: +// A strict version of non-reserved keywords, which can not be used as table alias. +// You can find the full keywords list by searching "Start of the keywords list" in this file. +// The strict-non-reserved keywords are listed in `strictNonReserved`. +// The non-reserved keywords are listed in `nonReserved`. 
+// These 2 together contain all the keywords. +strictNonReserved + : ANTI + | CROSS + | EXCEPT + | FULL + | INNER + | INTERSECT + | JOIN + | LEFT + | NATURAL + | ON + | RIGHT + | SEMI + | SETMINUS + | UNION + | USING + ; + +nonReserved +//--DEFAULT-NON-RESERVED-START + : ADD + | AFTER + | ALL + | ALTER + | ANALYZE + | AND + | ANY + | ARCHIVE + | ARRAY + | AS + | ASC + | AT + | AUTHORIZATION + | BETWEEN + | BOTH + | BUCKET + | BUCKETS + | BY + | CACHE + | CASCADE + | CASE + | CAST + | CHANGE + | CHECK + | CLEAR + | CLUSTER + | CLUSTERED + | CODEGEN + | COLLATE + | COLLECTION + | COLUMN + | COLUMNS + | COMMENT + | COMMIT + | COMPACT + | COMPACTIONS + | COMPUTE + | CONCATENATE + | CONSTRAINT + | COST + | CREATE + | CUBE + | CURRENT + | CURRENT_DATE + | CURRENT_TIME + | CURRENT_TIMESTAMP + | CURRENT_USER + | DATA + | DATABASE + | DATABASES + | DBPROPERTIES + | DEFINED + | DELETE + | DELIMITED + | DESC + | DESCRIBE + | DFS + | DIRECTORIES + | DIRECTORY + | DISTINCT + | DISTRIBUTE + | DIV + | DROP + | ELSE + | END + | ESCAPE + | ESCAPED + | EXCHANGE + | EXISTS + | EXPLAIN + | EXPORT + | EXTENDED + | EXTERNAL + | EXTRACT + | FALSE + | FETCH + | FILTER + | FIELDS + | FILEFORMAT + | FIRST + | FOLLOWING + | FOR + | FOREIGN + | FORMAT + | FORMATTED + | FROM + | FUNCTION + | FUNCTIONS + | GLOBAL + | GRANT + | GROUP + | GROUPING + | HAVING + | IF + | IGNORE + | IMPORT + | IN + | INDEX + | INDEXES + | INPATH + | INPUTFORMAT + | INSERT + | INTERVAL + | INTO + | IS + | ITEMS + | KEYS + | LAST + | LATERAL + | LAZY + | LEADING + | LIKE + | LIMIT + | LINES + | LIST + | LOAD + | LOCAL + | LOCATION + | LOCK + | LOCKS + | LOGICAL + | MACRO + | MAP + | MATCHED + | MERGE + | MSCK + | NAMESPACE + | NAMESPACES + | NO + | NOT + | NULL + | NULLS + | OF + | ONLY + | OPTION + | OPTIONS + | OR + | ORDER + | OUT + | OUTER + | OUTPUTFORMAT + | OVER + | OVERLAPS + | OVERLAY + | OVERWRITE + | PARTITION + | PARTITIONED + | PARTITIONS + | PERCENTLIT + | PIVOT + | PLACING + | POSITION + | PRECEDING + | PRIMARY + | PRINCIPALS + | PROPERTIES + | PURGE + | QUERY + | RANGE + | RECORDREADER + | RECORDWRITER + | RECOVER + | REDUCE + | REFERENCES + | REFRESH + | RENAME + | REPAIR + | REPLACE + | RESET + | RESTRICT + | REVOKE + | RLIKE + | ROLE + | ROLES + | ROLLBACK + | ROLLUP + | ROW + | ROWS + | SCHEMA + | SELECT + | SEPARATED + | SERDE + | SERDEPROPERTIES + | SESSION_USER + | SET + | SETS + | SHOW + | SKEWED + | SOME + | SORT + | SORTED + | START + | STATISTICS + | STORED + | STRATIFY + | STRUCT + | SUBSTR + | SUBSTRING + | TABLE + | TABLES + | TABLESAMPLE + | TBLPROPERTIES + | TEMPORARY + | TERMINATED + | THEN + | TIME + | TO + | TOUCH + | TRAILING + | TRANSACTION + | TRANSACTIONS + | TRANSFORM + | TRIM + | TRUE + | TRUNCATE + | TYPE + | UNARCHIVE + | UNBOUNDED + | UNCACHE + | UNIQUE + | UNKNOWN + | UNLOCK + | UNSET + | UPDATE + | USE + | USER + | VALUES + | VIEW + | VIEWS + | WHEN + | WHERE + | WINDOW + | WITH + | ZONE + | SYSTEM_VERSION + | VERSION + | SYSTEM_TIME + | TIMESTAMP +//--DEFAULT-NON-RESERVED-END + ; + +// NOTE: If you add a new token in the list below, you should update the list of keywords +// and reserved tag in `docs/sql-ref-ansi-compliance.md#sql-keywords`. 
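+// Illustrative examples (not part of the grammar; table and column names are hypothetical)
+// of statements covered by the rules above:
+//
+//   MERGE INTO (dmlStatementNoWith / matchedClause / notMatchedClause):
+//     MERGE INTO h0 AS target
+//     USING (SELECT 1 AS id, 'a1' AS name) AS source
+//     ON target.id = source.id
+//     WHEN MATCHED THEN UPDATE SET *
+//     WHEN NOT MATCHED THEN INSERT *;
+//
+//   Time travel (relationPrimary / temporalClause, added for Hudi):
+//     SELECT * FROM h0 TIMESTAMP AS OF '2021-07-28 14:11:08';
+//     SELECT * FROM h0 FOR SYSTEM_VERSION AS OF '20210728141108';
+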
+ +//============================ +// Start of the keywords list +//============================ +//--SPARK-KEYWORD-LIST-START +ADD: 'ADD'; +AFTER: 'AFTER'; +ALL: 'ALL'; +ALTER: 'ALTER'; +ANALYZE: 'ANALYZE'; +AND: 'AND'; +ANTI: 'ANTI'; +ANY: 'ANY'; +ARCHIVE: 'ARCHIVE'; +ARRAY: 'ARRAY'; +AS: 'AS'; +ASC: 'ASC'; +AT: 'AT'; +AUTHORIZATION: 'AUTHORIZATION'; +BETWEEN: 'BETWEEN'; +BOTH: 'BOTH'; +BUCKET: 'BUCKET'; +BUCKETS: 'BUCKETS'; +BY: 'BY'; +CACHE: 'CACHE'; +CASCADE: 'CASCADE'; +CASE: 'CASE'; +CAST: 'CAST'; +CHANGE: 'CHANGE'; +CHECK: 'CHECK'; +CLEAR: 'CLEAR'; +CLUSTER: 'CLUSTER'; +CLUSTERED: 'CLUSTERED'; +CODEGEN: 'CODEGEN'; +COLLATE: 'COLLATE'; +COLLECTION: 'COLLECTION'; +COLUMN: 'COLUMN'; +COLUMNS: 'COLUMNS'; +COMMENT: 'COMMENT'; +COMMIT: 'COMMIT'; +COMPACT: 'COMPACT'; +COMPACTIONS: 'COMPACTIONS'; +COMPUTE: 'COMPUTE'; +CONCATENATE: 'CONCATENATE'; +CONSTRAINT: 'CONSTRAINT'; +COST: 'COST'; +CREATE: 'CREATE'; +CROSS: 'CROSS'; +CUBE: 'CUBE'; +CURRENT: 'CURRENT'; +CURRENT_DATE: 'CURRENT_DATE'; +CURRENT_TIME: 'CURRENT_TIME'; +CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +CURRENT_USER: 'CURRENT_USER'; +DATA: 'DATA'; +DATABASE: 'DATABASE'; +DATABASES: 'DATABASES' | 'SCHEMAS'; +DBPROPERTIES: 'DBPROPERTIES'; +DEFINED: 'DEFINED'; +DELETE: 'DELETE'; +DELIMITED: 'DELIMITED'; +DESC: 'DESC'; +DESCRIBE: 'DESCRIBE'; +DFS: 'DFS'; +DIRECTORIES: 'DIRECTORIES'; +DIRECTORY: 'DIRECTORY'; +DISTINCT: 'DISTINCT'; +DISTRIBUTE: 'DISTRIBUTE'; +DIV: 'DIV'; +DROP: 'DROP'; +ELSE: 'ELSE'; +END: 'END'; +ESCAPE: 'ESCAPE'; +ESCAPED: 'ESCAPED'; +EXCEPT: 'EXCEPT'; +EXCHANGE: 'EXCHANGE'; +EXISTS: 'EXISTS'; +EXPLAIN: 'EXPLAIN'; +EXPORT: 'EXPORT'; +EXTENDED: 'EXTENDED'; +EXTERNAL: 'EXTERNAL'; +EXTRACT: 'EXTRACT'; +FALSE: 'FALSE'; +FETCH: 'FETCH'; +FIELDS: 'FIELDS'; +FILTER: 'FILTER'; +FILEFORMAT: 'FILEFORMAT'; +FIRST: 'FIRST'; +FOLLOWING: 'FOLLOWING'; +FOR: 'FOR'; +FOREIGN: 'FOREIGN'; +FORMAT: 'FORMAT'; +FORMATTED: 'FORMATTED'; +FROM: 'FROM'; +FULL: 'FULL'; +FUNCTION: 'FUNCTION'; +FUNCTIONS: 'FUNCTIONS'; +GLOBAL: 'GLOBAL'; +GRANT: 'GRANT'; +GROUP: 'GROUP'; +GROUPING: 'GROUPING'; +HAVING: 'HAVING'; +IF: 'IF'; +IGNORE: 'IGNORE'; +IMPORT: 'IMPORT'; +IN: 'IN'; +INDEX: 'INDEX'; +INDEXES: 'INDEXES'; +INNER: 'INNER'; +INPATH: 'INPATH'; +INPUTFORMAT: 'INPUTFORMAT'; +INSERT: 'INSERT'; +INTERSECT: 'INTERSECT'; +INTERVAL: 'INTERVAL'; +INTO: 'INTO'; +IS: 'IS'; +ITEMS: 'ITEMS'; +JOIN: 'JOIN'; +KEYS: 'KEYS'; +LAST: 'LAST'; +LATERAL: 'LATERAL'; +LAZY: 'LAZY'; +LEADING: 'LEADING'; +LEFT: 'LEFT'; +LIKE: 'LIKE'; +LIMIT: 'LIMIT'; +LINES: 'LINES'; +LIST: 'LIST'; +LOAD: 'LOAD'; +LOCAL: 'LOCAL'; +LOCATION: 'LOCATION'; +LOCK: 'LOCK'; +LOCKS: 'LOCKS'; +LOGICAL: 'LOGICAL'; +MACRO: 'MACRO'; +MAP: 'MAP'; +MATCHED: 'MATCHED'; +MERGE: 'MERGE'; +MSCK: 'MSCK'; +NAMESPACE: 'NAMESPACE'; +NAMESPACES: 'NAMESPACES'; +NATURAL: 'NATURAL'; +NO: 'NO'; +NOT: 'NOT' | '!'; +NULL: 'NULL'; +NULLS: 'NULLS'; +OF: 'OF'; +ON: 'ON'; +ONLY: 'ONLY'; +OPTION: 'OPTION'; +OPTIONS: 'OPTIONS'; +OR: 'OR'; +ORDER: 'ORDER'; +OUT: 'OUT'; +OUTER: 'OUTER'; +OUTPUTFORMAT: 'OUTPUTFORMAT'; +OVER: 'OVER'; +OVERLAPS: 'OVERLAPS'; +OVERLAY: 'OVERLAY'; +OVERWRITE: 'OVERWRITE'; +PARTITION: 'PARTITION'; +PARTITIONED: 'PARTITIONED'; +PARTITIONS: 'PARTITIONS'; +PERCENTLIT: 'PERCENT'; +PIVOT: 'PIVOT'; +PLACING: 'PLACING'; +POSITION: 'POSITION'; +PRECEDING: 'PRECEDING'; +PRIMARY: 'PRIMARY'; +PRINCIPALS: 'PRINCIPALS'; +PROPERTIES: 'PROPERTIES'; +PURGE: 'PURGE'; +QUERY: 'QUERY'; +RANGE: 'RANGE'; +RECORDREADER: 'RECORDREADER'; +RECORDWRITER: 'RECORDWRITER'; +RECOVER: 'RECOVER'; +REDUCE: 'REDUCE'; +REFERENCES: 
'REFERENCES'; +REFRESH: 'REFRESH'; +RENAME: 'RENAME'; +REPAIR: 'REPAIR'; +REPLACE: 'REPLACE'; +RESET: 'RESET'; +RESTRICT: 'RESTRICT'; +REVOKE: 'REVOKE'; +RIGHT: 'RIGHT'; +RLIKE: 'RLIKE' | 'REGEXP'; +ROLE: 'ROLE'; +ROLES: 'ROLES'; +ROLLBACK: 'ROLLBACK'; +ROLLUP: 'ROLLUP'; +ROW: 'ROW'; +ROWS: 'ROWS'; +SCHEMA: 'SCHEMA'; +SELECT: 'SELECT'; +SEMI: 'SEMI'; +SEPARATED: 'SEPARATED'; +SERDE: 'SERDE'; +SERDEPROPERTIES: 'SERDEPROPERTIES'; +SESSION_USER: 'SESSION_USER'; +SET: 'SET'; +SETMINUS: 'MINUS'; +SETS: 'SETS'; +SHOW: 'SHOW'; +SKEWED: 'SKEWED'; +SOME: 'SOME'; +SORT: 'SORT'; +SORTED: 'SORTED'; +START: 'START'; +STATISTICS: 'STATISTICS'; +STORED: 'STORED'; +STRATIFY: 'STRATIFY'; +STRUCT: 'STRUCT'; +SUBSTR: 'SUBSTR'; +SUBSTRING: 'SUBSTRING'; +TABLE: 'TABLE'; +TABLES: 'TABLES'; +TABLESAMPLE: 'TABLESAMPLE'; +TBLPROPERTIES: 'TBLPROPERTIES'; +TEMPORARY: 'TEMPORARY' | 'TEMP'; +TERMINATED: 'TERMINATED'; +THEN: 'THEN'; +TIME: 'TIME'; +TO: 'TO'; +TOUCH: 'TOUCH'; +TRAILING: 'TRAILING'; +TRANSACTION: 'TRANSACTION'; +TRANSACTIONS: 'TRANSACTIONS'; +TRANSFORM: 'TRANSFORM'; +TRIM: 'TRIM'; +TRUE: 'TRUE'; +TRUNCATE: 'TRUNCATE'; +TYPE: 'TYPE'; +UNARCHIVE: 'UNARCHIVE'; +UNBOUNDED: 'UNBOUNDED'; +UNCACHE: 'UNCACHE'; +UNION: 'UNION'; +UNIQUE: 'UNIQUE'; +UNKNOWN: 'UNKNOWN'; +UNLOCK: 'UNLOCK'; +UNSET: 'UNSET'; +UPDATE: 'UPDATE'; +USE: 'USE'; +USER: 'USER'; +USING: 'USING'; +VALUES: 'VALUES'; +VIEW: 'VIEW'; +VIEWS: 'VIEWS'; +WHEN: 'WHEN'; +WHERE: 'WHERE'; +WINDOW: 'WINDOW'; +WITH: 'WITH'; +ZONE: 'ZONE'; + +SYSTEM_VERSION: 'SYSTEM_VERSION'; +VERSION: 'VERSION'; +SYSTEM_TIME: 'SYSTEM_TIME'; +TIMESTAMP: 'TIMESTAMP'; +//--SPARK-KEYWORD-LIST-END +//============================ +// End of the keywords list +//============================ + +EQ : '=' | '=='; +NSEQ: '<=>'; +NEQ : '<>'; +NEQJ: '!='; +LT : '<'; +LTE : '<=' | '!>'; +GT : '>'; +GTE : '>=' | '!<'; + +PLUS: '+'; +MINUS: '-'; +ASTERISK: '*'; +SLASH: '/'; +PERCENT: '%'; +TILDE: '~'; +AMPERSAND: '&'; +PIPE: '|'; +CONCAT_PIPE: '||'; +HAT: '^'; + +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '"' ( ~('"'|'\\') | ('\\' .) )* '"' + ; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +EXPONENT_VALUE + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? + ; + +DECIMAL_VALUE + : DECIMAL_DIGITS {isValidDecimal()}? + ; + +FLOAT_LITERAL + : DIGIT+ EXPONENT? 'F' + | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}? + ; + +DOUBLE_LITERAL + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + +BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' {!isHint()}? (BRACKETED_COMMENT|.)*? '*/' -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . 
+ ; diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 b/hudi-spark-datasource/hudi-spark3.1.x/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 new file mode 100644 index 0000000000000..0ee6bee1970c1 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +grammar HoodieSqlBase; + +import SqlBase; + +singleStatement + : statement EOF + ; + +statement + : query #queryStatement + | ctes? dmlStatementNoWith #dmlStatement + | createTableHeader ('(' colTypeList ')')? tableProvider? + createTableClauses + (AS? query)? #createTable + | replaceTableHeader ('(' colTypeList ')')? tableProvider? + createTableClauses + (AS? query)? #replaceTable + | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? + VIEW (IF NOT EXISTS)? multipartIdentifier + identifierCommentList? + (commentSpec | + (PARTITIONED ON identifierList) | + (TBLPROPERTIES tablePropertyList))* + AS query #createView + | ALTER VIEW multipartIdentifier AS? query #alterViewQuery + | (DESC | DESCRIBE) QUERY? query #describeQuery + | CACHE LAZY? TABLE multipartIdentifier + (OPTIONS options=tablePropertyList)? (AS? query)? #cacheTable + | .*? #passThrough + ; diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala new file mode 100644 index 0000000000000..94e18d916b9f3 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/adapter/Spark3_1Adapter.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.adapter + +import org.apache.hudi.Spark3RowSerDe +import org.apache.hudi.client.utils.SparkRowSerDe +import org.apache.hudi.spark3.internal.ReflectUtil +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Expression, Like} +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, JoinHint, LogicalPlan} +import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ +import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile, Spark3ParsePartitionUtil, SparkParsePartitionUtil} +import org.apache.spark.sql.hudi.SparkAdapter +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.parser.HoodieSpark3_1ExtendedSqlParser +import org.apache.spark.sql.{Row, SparkSession} + +/** + * The adapter for spark3. + */ +class Spark3_1Adapter extends SparkAdapter { + + override def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = { + new Spark3RowSerDe(encoder) + } + + override def toTableIdentifier(aliasId: AliasIdentifier): TableIdentifier = { + aliasId match { + case AliasIdentifier(name, Seq(database)) => + TableIdentifier(name, Some(database)) + case AliasIdentifier(name, Seq(_, database)) => + TableIdentifier(name, Some(database)) + case AliasIdentifier(name, Seq()) => + TableIdentifier(name, None) + case _=> throw new IllegalArgumentException(s"Cannot cast $aliasId to TableIdentifier") + } + } + + override def toTableIdentifier(relation: UnresolvedRelation): TableIdentifier = { + relation.multipartIdentifier.asTableIdentifier + } + + override def createJoin(left: LogicalPlan, right: LogicalPlan, joinType: JoinType): Join = { + Join(left, right, joinType, None, JoinHint.NONE) + } + + override def isInsertInto(plan: LogicalPlan): Boolean = { + plan.isInstanceOf[InsertIntoStatement] + } + + override def getInsertIntoChildren(plan: LogicalPlan): + Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] = { + plan match { + case insert: InsertIntoStatement => + Some((insert.table, insert.partitionSpec, insert.query, insert.overwrite, insert.ifPartitionNotExists)) + case _ => + None + } + } + + override def createInsertInto(table: LogicalPlan, partition: Map[String, Option[String]], + query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean): LogicalPlan = { + ReflectUtil.createInsertInto(table, partition, Seq.empty[String], query, overwrite, ifPartitionNotExists) + } + + override def createSparkParsePartitionUtil(conf: SQLConf): SparkParsePartitionUtil = { + new Spark3ParsePartitionUtil(conf) + } + + override def createLike(left: Expression, right: Expression): Expression = { + new Like(left, right) + } + + override def parseMultipartIdentifier(parser: ParserInterface, sqlText: String): Seq[String] = { + parser.parseMultipartIdentifier(sqlText) + } + + override def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = { + Some( + (spark: SparkSession, delegate: ParserInterface) => new HoodieSpark3_1ExtendedSqlParser(spark, delegate) + ) + } + + /** + * Combine [[PartitionedFile]] to [[FilePartition]] according to `maxSplitBytes`. 
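+   * For example (illustrative): with maxSplitBytes = 128 MB, three 60 MB files would
+   * typically be packed into two partitions, [60 MB, 60 MB] and [60 MB]. This simply
+   * delegates to Spark's FilePartition.getFilePartitions.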
+ */ + override def getFilePartitions( + sparkSession: SparkSession, + partitionedFiles: Seq[PartitionedFile], + maxSplitBytes: Long): Seq[FilePartition] = { + FilePartition.getFilePartitions(sparkSession, partitionedFiles, maxSplitBytes) + } +} diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlAstBuilder.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlAstBuilder.scala new file mode 100644 index 0000000000000..eefd3cf407d95 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlAstBuilder.scala @@ -0,0 +1,3152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.parser + +import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} +import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser._ +import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseVisitor, HoodieSqlBaseParser} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, FunctionResource, FunctionResourceType} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} +import org.apache.spark.sql.catalyst.parser.ParserUtils.{EnhancedLogicalPlan, checkDuplicateClauses, checkDuplicateKeys, escapedIdentifier, operationNotAllowed, source, string, stringWithoutUnescape, validate, withOrigin} +import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} +import org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils, truncatedString} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition +import org.apache.spark.sql.connector.catalog.{SupportsNamespaces, TableCatalog} +import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform, Expression => V2Expression} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import 
org.apache.spark.util.random.RandomSampler + +import java.util.Locale +import javax.xml.bind.DatatypeConverter +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +/** + * The AstBuilder for HoodieSqlParser to parser the AST tree to Logical Plan. + * Here we only do the parser for the extended sql syntax. e.g MergeInto. For + * other sql syntax we use the delegate sql parser which is the SparkSqlParser. + */ +class HoodieSpark3_1ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface) + extends HoodieSqlBaseBaseVisitor[AnyRef] with Logging { + + protected def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } + + /** + * Override the default behavior for all visit methods. This will only return a non-null result + * when the context has only one child. This is done because there is no generic method to + * combine the results of the context children. In all other cases null is returned. + */ + override def visitChildren(node: RuleNode): AnyRef = { + if (node.getChildCount == 1) { + node.getChild(0).accept(this) + } else { + null + } + } + + /** + * Create an aliased table reference. This is typically used in FROM clauses. + */ + override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { + val tableId = visitMultipartIdentifier(ctx.multipartIdentifier()) + val relation = UnresolvedRelation(tableId) + val table = mayApplyAliasPlan( + ctx.tableAlias, relation.optionalMap(ctx.temporalClause)(withTimeTravel)) + table.optionalMap(ctx.sample)(withSample) + } + + private def withTimeTravel( + ctx: TemporalClauseContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val v = ctx.version + val version = if (ctx.INTEGER_VALUE != null) { + Some(v.getText) + } else { + Option(v).map(string) + } + + val timestamp = Option(ctx.timestamp).map(expression) + if (timestamp.exists(_.references.nonEmpty)) { + throw new ParseException( + "timestamp expression cannot refer to any columns", ctx.timestamp) + } + if (timestamp.exists(e => SubqueryExpression.hasSubquery(e))) { + throw new ParseException( + "timestamp expression cannot contain subqueries", ctx.timestamp) + } + + TimeTravelRelation(plan, timestamp, version) + } + + // ============== The following code is fork from org.apache.spark.sql.catalyst.parser.AstBuilder + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { + visit(ctx.statement).asInstanceOf[LogicalPlan] + } + + override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) { + visitNamedExpression(ctx.namedExpression) + } + + override def visitSingleTableIdentifier( + ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) { + visitTableIdentifier(ctx.tableIdentifier) + } + + override def visitSingleFunctionIdentifier( + ctx: SingleFunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + visitFunctionIdentifier(ctx.functionIdentifier) + } + + override def visitSingleMultipartIdentifier( + ctx: SingleMultipartIdentifierContext): Seq[String] = withOrigin(ctx) { + visitMultipartIdentifier(ctx.multipartIdentifier) + } + + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { + typedVisit[DataType](ctx.dataType) + } + + override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = { + val schema = StructType(visitColTypeList(ctx.colTypeList)) + withOrigin(ctx)(schema) + } + + /* 
******************************************************************************************** + * Plan parsing + * ******************************************************************************************** */ + protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree) + + /** + * Create a top-level plan with Common Table Expressions. + */ + override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) { + val query = plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) + + // Apply CTEs + query.optionalMap(ctx.ctes)(withCTE) + } + + override def visitDmlStatement(ctx: DmlStatementContext): AnyRef = withOrigin(ctx) { + val dmlStmt = plan(ctx.dmlStatementNoWith) + // Apply CTEs + dmlStmt.optionalMap(ctx.ctes)(withCTE) + } + + private def withCTE(ctx: CtesContext, plan: LogicalPlan): LogicalPlan = { + val ctes = ctx.namedQuery.asScala.map { nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) + } + // Check for duplicate names. + val duplicates = ctes.groupBy(_._1).filter(_._2.size > 1).keys + if (duplicates.nonEmpty) { + throw new ParseException( + s"CTE definition can't have duplicate names: ${duplicates.mkString("'", "', '", "'")}.", + ctx) + } + With(plan, ctes.toSeq) + } + + /** + * Create a logical query plan for a hive-style FROM statement body. + */ + private def withFromStatementBody( + ctx: FromStatementBodyContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // two cases for transforms and selects + if (ctx.transformClause != null) { + withTransformQuerySpecification( + ctx, + ctx.transformClause, + ctx.whereClause, + plan + ) + } else { + withSelectQuerySpecification( + ctx, + ctx.selectClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + plan + ) + } + } + + override def visitFromStatement(ctx: FromStatementContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + val selects = ctx.fromStatementBody.asScala.map { body => + withFromStatementBody(body, from). + // Add organization statements. + optionalMap(body.queryOrganization)(withQueryResultClauses) + } + // If there are multiple SELECT just UNION them together into one query. + if (selects.length == 1) { + selects.head + } else { + Union(selects.toSeq) + } + } + + /** + * Create a named logical plan. + * + * This is only used for Common Table Expressions. + */ + override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { + val subQuery: LogicalPlan = plan(ctx.query).optionalMap(ctx.columnAliases)( + (columnAliases, plan) => + UnresolvedSubqueryColumnAliases(visitIdentifierList(columnAliases), plan) + ) + SubqueryAlias(ctx.name.getText, subQuery) + } + + /** + * Create a logical plan which allows for multiple inserts using one 'from' statement. These + * queries have the following SQL form: + * {{{ + * [WITH cte...]? + * FROM src + * [INSERT INTO tbl1 SELECT *]+ + * }}} + * For example: + * {{{ + * FROM db.tbl1 A + * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5 + * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12 + * }}} + * This (Hive) feature cannot be combined with set-operators. + */ + override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + + // Build the insert clauses. 
+ val inserts = ctx.multiInsertQueryBody.asScala.map { body => + withInsertInto(body.insertInto, + withFromStatementBody(body.fromStatementBody, from). + optionalMap(body.fromStatementBody.queryOrganization)(withQueryResultClauses)) + } + + // If there are multiple INSERTS just UNION them together into one query. + if (inserts.length == 1) { + inserts.head + } else { + Union(inserts.toSeq) + } + } + + /** + * Create a logical plan for a regular (single-insert) query. + */ + override def visitSingleInsertQuery( + ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { + withInsertInto( + ctx.insertInto(), + plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses)) + } + + /** + * Parameters used for writing query to a table: + * (multipartIdentifier, tableColumnList, partitionKeys, ifPartitionNotExists). + */ + type InsertTableParams = (Seq[String], Seq[String], Map[String, Option[String]], Boolean) + + /** + * Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider). + */ + type InsertDirParams = (Boolean, CatalogStorageFormat, Option[String]) + + /** + * Add an + * {{{ + * INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? [identifierList] + * INSERT INTO [TABLE] tableIdentifier [partitionSpec] [identifierList] + * INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat] + * INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS tablePropertyList] + * }}} + * operation to logical plan + */ + private def withInsertInto( + ctx: InsertIntoContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + ctx match { + case table: InsertIntoTableContext => + val (tableIdent, cols, partition, ifPartitionNotExists) = visitInsertIntoTable(table) + InsertIntoStatement( + UnresolvedRelation(tableIdent), + partition, + cols, + query, + overwrite = false, + ifPartitionNotExists) + case table: InsertOverwriteTableContext => + val (tableIdent, cols, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table) + InsertIntoStatement( + UnresolvedRelation(tableIdent), + partition, + cols, + query, + overwrite = true, + ifPartitionNotExists) + case dir: InsertOverwriteDirContext => + val (isLocal, storage, provider) = visitInsertOverwriteDir(dir) + InsertIntoDir(isLocal, storage, provider, query, overwrite = true) + case hiveDir: InsertOverwriteHiveDirContext => + val (isLocal, storage, provider) = visitInsertOverwriteHiveDir(hiveDir) + InsertIntoDir(isLocal, storage, provider, query, overwrite = true) + case _ => + throw new ParseException("Invalid InsertIntoContext", ctx) + } + } + + /** + * Add an INSERT INTO TABLE operation to the logical plan. + */ + override def visitInsertIntoTable( + ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) { + val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier) + val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + if (ctx.EXISTS != null) { + operationNotAllowed("INSERT INTO ... IF NOT EXISTS", ctx) + } + + (tableIdent, cols, partitionKeys, false) + } + + /** + * Add an INSERT OVERWRITE TABLE operation to the logical plan. 
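+   * For example (illustrative, hypothetical table names):
+   *   INSERT OVERWRITE TABLE h0 PARTITION (dt = '2021-01-01') SELECT id, name FROM src
+   * Note that IF NOT EXISTS cannot be combined with dynamic partition keys, e.g. PARTITION (dt).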
+ */ + override def visitInsertOverwriteTable( + ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) { + assert(ctx.OVERWRITE() != null) + val tableIdent = visitMultipartIdentifier(ctx.multipartIdentifier) + val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty) + if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) { + operationNotAllowed("IF NOT EXISTS with dynamic partitions: " + + dynamicPartitionKeys.keys.mkString(", "), ctx) + } + + (tableIdent, cols, partitionKeys, ctx.EXISTS() != null) + } + + /** + * Write to a directory, returning a [[InsertIntoDir]] logical plan. + */ + override def visitInsertOverwriteDir( + ctx: InsertOverwriteDirContext): InsertDirParams = withOrigin(ctx) { + throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx) + } + + /** + * Write to a directory, returning a [[InsertIntoDir]] logical plan. + */ + override def visitInsertOverwriteHiveDir( + ctx: InsertOverwriteHiveDirContext): InsertDirParams = withOrigin(ctx) { + throw new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx) + } + + private def getTableAliasWithoutColumnAlias( + ctx: TableAliasContext, op: String): Option[String] = { + if (ctx == null) { + None + } else { + val ident = ctx.strictIdentifier() + if (ctx.identifierList() != null) { + throw new ParseException(s"Columns aliases are not allowed in $op.", ctx.identifierList()) + } + if (ident != null) Some(ident.getText) else None + } + } + + override def visitDeleteFromTable( + ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) { + val table = UnresolvedRelation(visitMultipartIdentifier(ctx.multipartIdentifier())) + val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE") + val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table) + val predicate = if (ctx.whereClause() != null) { + Some(expression(ctx.whereClause().booleanExpression())) + } else { + None + } + DeleteFromTable(aliasedTable, predicate) + } + + override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) { + val table = UnresolvedRelation(visitMultipartIdentifier(ctx.multipartIdentifier())) + val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE") + val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table) + val assignments = withAssignments(ctx.setClause().assignmentList()) + val predicate = if (ctx.whereClause() != null) { + Some(expression(ctx.whereClause().booleanExpression())) + } else { + None + } + + UpdateTable(aliasedTable, assignments, predicate) + } + + private def withAssignments(assignCtx: AssignmentListContext): Seq[Assignment] = + withOrigin(assignCtx) { + assignCtx.assignment().asScala.map { assign => + Assignment(UnresolvedAttribute(visitMultipartIdentifier(assign.key)), + expression(assign.value)) + }.toSeq + } + + override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) { + val targetTable = UnresolvedRelation(visitMultipartIdentifier(ctx.target)) + val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE") + val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable) + + val sourceTableOrQuery = if (ctx.source != null) { + UnresolvedRelation(visitMultipartIdentifier(ctx.source)) + } else if 
(ctx.sourceQuery != null) { + visitQuery(ctx.sourceQuery) + } else { + throw new ParseException("Empty source for merge: you should specify a source" + + " table/subquery in merge.", ctx.source) + } + val sourceTableAlias = getTableAliasWithoutColumnAlias(ctx.sourceAlias, "MERGE") + val aliasedSource = + sourceTableAlias.map(SubqueryAlias(_, sourceTableOrQuery)).getOrElse(sourceTableOrQuery) + + val mergeCondition = expression(ctx.mergeCondition) + + val matchedActions = ctx.matchedClause().asScala.map { + clause => { + if (clause.matchedAction().DELETE() != null) { + DeleteAction(Option(clause.matchedCond).map(expression)) + } else if (clause.matchedAction().UPDATE() != null) { + val condition = Option(clause.matchedCond).map(expression) + if (clause.matchedAction().ASTERISK() != null) { + UpdateAction(condition, Seq()) + } else { + UpdateAction(condition, withAssignments(clause.matchedAction().assignmentList())) + } + } else { + // It should not be here. + throw new ParseException( + s"Unrecognized matched action: ${clause.matchedAction().getText}", + clause.matchedAction()) + } + } + } + val notMatchedActions = ctx.notMatchedClause().asScala.map { + clause => { + if (clause.notMatchedAction().INSERT() != null) { + val condition = Option(clause.notMatchedCond).map(expression) + if (clause.notMatchedAction().ASTERISK() != null) { + InsertAction(condition, Seq()) + } else { + val columns = clause.notMatchedAction().columns.multipartIdentifier() + .asScala.map(attr => UnresolvedAttribute(visitMultipartIdentifier(attr))) + val values = clause.notMatchedAction().expression().asScala.map(expression) + if (columns.size != values.size) { + throw new ParseException("The number of inserted values cannot match the fields.", + clause.notMatchedAction()) + } + InsertAction(condition, columns.zip(values).map(kv => Assignment(kv._1, kv._2)).toSeq) + } + } else { + // It should not be here. + throw new ParseException( + s"Unrecognized not matched action: ${clause.notMatchedAction().getText}", + clause.notMatchedAction()) + } + } + } + if (matchedActions.isEmpty && notMatchedActions.isEmpty) { + throw new ParseException("There must be at least one WHEN clause in a MERGE statement", ctx) + } + // children being empty means that the condition is not set + val matchedActionSize = matchedActions.length + if (matchedActionSize >= 2 && !matchedActions.init.forall(_.condition.nonEmpty)) { + throw new ParseException("When there are more than one MATCHED clauses in a MERGE " + + "statement, only the last MATCHED clause can omit the condition.", ctx) + } + val notMatchedActionSize = notMatchedActions.length + if (notMatchedActionSize >= 2 && !notMatchedActions.init.forall(_.condition.nonEmpty)) { + throw new ParseException("When there are more than one NOT MATCHED clauses in a MERGE " + + "statement, only the last NOT MATCHED clause can omit the condition.", ctx) + } + + MergeIntoTable( + aliasedTarget, + aliasedSource, + mergeCondition, + matchedActions.toSeq, + notMatchedActions.toSeq) + } + + /** + * Create a partition specification map. 
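+ * For example, {{{PARTITION (dt = '2021-01-01', country)}}} is parsed into
+ * {{{Map("dt" -> Some("2021-01-01"), "country" -> None)}}}; a missing value marks the column
+ * as a dynamic partition. Duplicate keys are rejected below.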
+ */ + override def visitPartitionSpec( + ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { + val legacyNullAsString = + conf.getConf(SQLConf.LEGACY_PARSE_NULL_PARTITION_SPEC_AS_STRING_LITERAL) + val parts = ctx.partitionVal.asScala.map { pVal => + val name = pVal.identifier.getText + val value = Option(pVal.constant).map(v => visitStringConstant(v, legacyNullAsString)) + name -> value + } + // Before calling `toMap`, we check duplicated keys to avoid silently ignore partition values + // in partition spec like PARTITION(a='1', b='2', a='3'). The real semantical check for + // partition columns will be done in analyzer. + if (conf.caseSensitiveAnalysis) { + checkDuplicateKeys(parts.toSeq, ctx) + } else { + checkDuplicateKeys(parts.map(kv => kv._1.toLowerCase(Locale.ROOT) -> kv._2).toSeq, ctx) + } + parts.toMap + } + + /** + * Create a partition specification map without optional values. + */ + protected def visitNonOptionalPartitionSpec( + ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { + visitPartitionSpec(ctx).map { + case (key, None) => throw new ParseException(s"Found an empty partition key '$key'.", ctx) + case (key, Some(value)) => key -> value + } + } + + /** + * Convert a constant of any type into a string. This is typically used in DDL commands, and its + * main purpose is to prevent slight differences due to back to back conversions i.e.: + * String -> Literal -> String. + */ + protected def visitStringConstant( + ctx: ConstantContext, + legacyNullAsString: Boolean): String = withOrigin(ctx) { + ctx match { + case _: NullLiteralContext if !legacyNullAsString => null + case s: StringLiteralContext => createString(s) + case o => o.getText + } + } + + /** + * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These + * clauses determine the shape (ordering/partitioning/rows) of the query result. + */ + private def withQueryResultClauses( + ctx: QueryOrganizationContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. + val withOrder = if ( + !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // ORDER BY ... + Sort(order.asScala.map(visitSortItem).toSeq, global = true, query) + } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... + Sort(sort.asScala.map(visitSortItem).toSeq, global = false, query) + } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // DISTRIBUTE BY ... + withRepartitionByExpression(ctx, expressionList(distributeBy), query) + } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... DISTRIBUTE BY ... + Sort( + sort.asScala.map(visitSortItem).toSeq, + global = false, + withRepartitionByExpression(ctx, expressionList(distributeBy), query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { + // CLUSTER BY ... 
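+ // CLUSTER BY exprs is shorthand for DISTRIBUTE BY exprs SORT BY exprs: repartition by the
+ // expressions, then apply a partition-local (global = false) ascending sort.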
+ val expressions = expressionList(clusterBy) + Sort( + expressions.map(SortOrder(_, Ascending)), + global = false, + withRepartitionByExpression(ctx, expressions, query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // [EMPTY] + query + } else { + throw new ParseException( + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx) + } + + // WINDOWS + val withWindow = withOrder.optionalMap(windowClause)(withWindowClause) + + // LIMIT + // - LIMIT ALL is the same as omitting the LIMIT clause + withWindow.optional(limit) { + Limit(typedVisit(limit), withWindow) + } + } + + /** + * Create a clause for DISTRIBUTE BY. + */ + protected def withRepartitionByExpression( + ctx: QueryOrganizationContext, + expressions: Seq[Expression], + query: LogicalPlan): LogicalPlan = { + throw new ParseException("DISTRIBUTE BY is not supported", ctx) + } + + override def visitTransformQuerySpecification( + ctx: TransformQuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation().optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withTransformQuerySpecification(ctx, ctx.transformClause, ctx.whereClause, from) + } + + override def visitRegularQuerySpecification( + ctx: RegularQuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation().optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withSelectQuerySpecification( + ctx, + ctx.selectClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + from + ) + } + + override def visitNamedExpressionSeq( + ctx: NamedExpressionSeqContext): Seq[Expression] = { + Option(ctx).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + } + + /** + * Create a logical plan using a having clause. + */ + private def withHavingClause( + ctx: HavingClauseContext, plan: LogicalPlan): LogicalPlan = { + // Note that we add a cast to non-predicate expressions. If the expression itself is + // already boolean, the optimizer will get rid of the unnecessary cast. + val predicate = expression(ctx.booleanExpression) match { + case p: Predicate => p + case e => Cast(e, BooleanType) + } + UnresolvedHaving(predicate, plan) + } + + /** + * Create a logical plan using a where clause. + */ + private def withWhereClause(ctx: WhereClauseContext, plan: LogicalPlan): LogicalPlan = { + Filter(expression(ctx.booleanExpression), plan) + } + + /** + * Add a hive-style transform (SELECT TRANSFORM/MAP/REDUCE) query specification to a logical plan. + */ + private def withTransformQuerySpecification( + ctx: ParserRuleContext, + transformClause: TransformClauseContext, + whereClause: WhereClauseContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Add where. + val withFilter = relation.optionalMap(whereClause)(withWhereClause) + + // Create the transform. + val expressions = visitNamedExpressionSeq(transformClause.namedExpressionSeq) + + // Create the attributes. + val (attributes, schemaLess) = if (transformClause.colTypeList != null) { + // Typed return columns. + (createSchema(transformClause.colTypeList).toAttributes, false) + } else if (transformClause.identifierSeq != null) { + // Untyped return columns. 
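+ // e.g. SELECT TRANSFORM(...) USING 'script' AS (a, b); untyped output columns default to string.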
+ val attrs = visitIdentifierSeq(transformClause.identifierSeq).map { name => + AttributeReference(name, StringType, nullable = true)() + } + (attrs, false) + } else { + (Seq(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) + } + + // Create the transform. + ScriptTransformation( + expressions, + string(transformClause.script), + attributes, + withFilter, + withScriptIOSchema( + ctx, + transformClause.inRowFormat, + transformClause.recordWriter, + transformClause.outRowFormat, + transformClause.recordReader, + schemaLess + ) + ) + } + + /** + * Add a regular (SELECT) query specification to a logical plan. The query specification + * is the core of the logical plan, this is where sourcing (FROM clause), projection (SELECT), + * aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place. + * + * Note that query hints are ignored (both by the parser and the builder). + */ + private def withSelectQuerySpecification( + ctx: ParserRuleContext, + selectClause: SelectClauseContext, + lateralView: java.util.List[LateralViewContext], + whereClause: WhereClauseContext, + aggregationClause: AggregationClauseContext, + havingClause: HavingClauseContext, + windowClause: WindowClauseContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Add lateral views. + val withLateralView = lateralView.asScala.foldLeft(relation)(withGenerate) + + // Add where. + val withFilter = withLateralView.optionalMap(whereClause)(withWhereClause) + + val expressions = visitNamedExpressionSeq(selectClause.namedExpressionSeq) + // Add aggregation or a project. + val namedExpressions = expressions.map { + case e: NamedExpression => e + case e: Expression => UnresolvedAlias(e) + } + + def createProject() = if (namedExpressions.nonEmpty) { + Project(namedExpressions, withFilter) + } else { + withFilter + } + + val withProject = if (aggregationClause == null && havingClause != null) { + if (conf.getConf(SQLConf.LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE)) { + // If the legacy conf is set, treat HAVING without GROUP BY as WHERE. + val predicate = expression(havingClause.booleanExpression) match { + case p: Predicate => p + case e => Cast(e, BooleanType) + } + Filter(predicate, createProject()) + } else { + // According to SQL standard, HAVING without GROUP BY means global aggregate. + withHavingClause(havingClause, Aggregate(Nil, namedExpressions, withFilter)) + } + } else if (aggregationClause != null) { + val aggregate = withAggregationClause(aggregationClause, namedExpressions, withFilter) + aggregate.optionalMap(havingClause)(withHavingClause) + } else { + // When hitting this branch, `having` must be null. + createProject() + } + + // Distinct + val withDistinct = if ( + selectClause.setQuantifier() != null && + selectClause.setQuantifier().DISTINCT() != null) { + Distinct(withProject) + } else { + withProject + } + + // Window + val withWindow = withDistinct.optionalMap(windowClause)(withWindowClause) + + // Hint + selectClause.hints.asScala.foldRight(withWindow)(withHints) + } + + /** + * Create a (Hive based) [[ScriptInputOutputSchema]]. + */ + protected def withScriptIOSchema( + ctx: ParserRuleContext, + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): ScriptInputOutputSchema = { + throw new ParseException("Script Transform is not supported", ctx) + } + + /** + * Create a logical plan for a given 'FROM' clause. 
Note that we support multiple (comma + * separated) relations here, these get converted into a single plan by condition-less inner join. + */ + override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { + val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) => + val right = plan(relation.relationPrimary) + val join = right.optionalMap(left)(Join(_, _, Inner, None, JoinHint.NONE)) + withJoinRelations(join, relation) + } + if (ctx.pivotClause() != null) { + if (!ctx.lateralView.isEmpty) { + throw new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx) + } + withPivot(ctx.pivotClause, from) + } else { + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } + } + + /** + * Connect two queries by a Set operator. + * + * Supported Set operators are: + * - UNION [ DISTINCT | ALL ] + * - EXCEPT [ DISTINCT | ALL ] + * - MINUS [ DISTINCT | ALL ] + * - INTERSECT [DISTINCT | ALL] + */ + override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { + val left = plan(ctx.left) + val right = plan(ctx.right) + val all = Option(ctx.setQuantifier()).exists(_.ALL != null) + ctx.operator.getType match { + case HoodieSqlBaseParser.UNION if all => + Union(left, right) + case HoodieSqlBaseParser.UNION => + Distinct(Union(left, right)) + case HoodieSqlBaseParser.INTERSECT if all => + Intersect(left, right, isAll = true) + case HoodieSqlBaseParser.INTERSECT => + Intersect(left, right, isAll = false) + case HoodieSqlBaseParser.EXCEPT if all => + Except(left, right, isAll = true) + case HoodieSqlBaseParser.EXCEPT => + Except(left, right, isAll = false) + case HoodieSqlBaseParser.SETMINUS if all => + Except(left, right, isAll = true) + case HoodieSqlBaseParser.SETMINUS => + Except(left, right, isAll = false) + } + } + + /** + * Add a [[WithWindowDefinition]] operator to a logical plan. + */ + private def withWindowClause( + ctx: WindowClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Collect all window specifications defined in the WINDOW clause. + val baseWindowTuples = ctx.namedWindow.asScala.map { + wCtx => + (wCtx.name.getText, typedVisit[WindowSpec](wCtx.windowSpec)) + } + baseWindowTuples.groupBy(_._1).foreach { kv => + if (kv._2.size > 1) { + throw new ParseException(s"The definition of window '${kv._1}' is repetitive", ctx) + } + } + val baseWindowMap = baseWindowTuples.toMap + + // Handle cases like + // window w1 as (partition by p_mfgr order by p_name + // range between 2 preceding and 2 following), + // w2 as w1 + val windowMapView = baseWindowMap.mapValues { + case WindowSpecReference(name) => + baseWindowMap.get(name) match { + case Some(spec: WindowSpecDefinition) => + spec + case Some(ref) => + throw new ParseException(s"Window reference '$name' is not a window specification", ctx) + case None => + throw new ParseException(s"Cannot resolve window reference '$name'", ctx) + } + case spec: WindowSpecDefinition => spec + } + + // Note that mapValues creates a view instead of materialized map. We force materialization by + // mapping over identity. + WithWindowDefinition(windowMapView.map(identity).toMap, query) + } + + /** + * Add an [[Aggregate]] or [[GroupingSets]] to a logical plan. 
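+ * GROUP BY ... GROUPING SETS (...) becomes a [[GroupingSets]] node; WITH CUBE and WITH ROLLUP
+ * wrap the grouping expressions in [[Cube]] and [[Rollup]] before building the [[Aggregate]].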
+ */ + private def withAggregationClause( + ctx: AggregationClauseContext, + selectExpressions: Seq[NamedExpression], + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val groupByExpressions = expressionList(ctx.groupingExpressions) + + if (ctx.GROUPING != null) { + // GROUP BY .... GROUPING SETS (...) + val selectedGroupByExprs = + ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)).toSeq) + GroupingSets(selectedGroupByExprs.toSeq, groupByExpressions, query, selectExpressions) + } else { + // GROUP BY .... (WITH CUBE | WITH ROLLUP)? + val mappedGroupByExpressions = if (ctx.CUBE != null) { + Seq(Cube(groupByExpressions)) + } else if (ctx.ROLLUP != null) { + Seq(Rollup(groupByExpressions)) + } else { + groupByExpressions + } + Aggregate(mappedGroupByExpressions, selectExpressions, query) + } + } + + /** + * Add [[UnresolvedHint]]s to a logical plan. + */ + private def withHints( + ctx: HintContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + var plan = query + ctx.hintStatements.asScala.reverse.foreach { stmt => + plan = UnresolvedHint(stmt.hintName.getText, + stmt.parameters.asScala.map(expression).toSeq, plan) + } + plan + } + + /** + * Add a [[Pivot]] to a logical plan. + */ + private def withPivot( + ctx: PivotClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val aggregates = Option(ctx.aggregates).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) { + UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText) + } else { + CreateStruct( + ctx.pivotColumn.identifiers.asScala.map( + identifier => UnresolvedAttribute.quoted(identifier.getText)).toSeq) + } + val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue) + Pivot(None, pivotColumn, pivotValues.toSeq, aggregates, query) + } + + /** + * Create a Pivot column value with or without an alias. + */ + override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else { + e + } + } + + /** + * Add a [[Generate]] (Lateral View) to a logical plan. + */ + private def withGenerate( + query: LogicalPlan, + ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) { + val expressions = expressionList(ctx.expression) + Generate( + UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions), + unrequiredChildIndex = Nil, + outer = ctx.OUTER != null, + // scalastyle:off caselocale + Some(ctx.tblName.getText.toLowerCase), + // scalastyle:on caselocale + ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply).toSeq, + query) + } + + /** + * Create a single relation referenced in a FROM clause. This method is used when a part of the + * join condition is nested, for example: + * {{{ + * select * from t1 join (t2 cross join t3) on col1 = col2 + * }}} + */ + override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) { + withJoinRelations(plan(ctx.relationPrimary), ctx) + } + + /** + * Join one more [[LogicalPlan]]s to the current logical plan. 
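+ * Joins are folded left to right over the base plan; USING and NATURAL join criteria are
+ * handled here, and NATURAL CROSS JOIN is rejected with a parse error.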
+ */ + private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = { + ctx.joinRelation.asScala.foldLeft(base) { (left, join) => + withOrigin(join) { + val baseJoinType = join.joinType match { + case null => Inner + case jt if jt.CROSS != null => Cross + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.ANTI != null => LeftAnti + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } + + // Resolve the join type and join condition + val (joinType, condition) = Option(join.joinCriteria) match { + case Some(c) if c.USING != null => + (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case Some(c) => + throw new ParseException(s"Unimplemented joinCriteria: $c", ctx) + case None if join.NATURAL != null => + if (baseJoinType == Cross) { + throw new ParseException("NATURAL CROSS JOIN is not supported", ctx) + } + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + Join(left, plan(join.right), joinType, condition, JoinHint.NONE) + } + } + } + + /** + * Add a [[Sample]] to a logical plan. + * + * This currently supports the following sampling methods: + * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. + * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages + * are defined as a number between 0 and 100. + * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction. + */ + private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Create a sampled plan if we need one. + def sample(fraction: Double): Sample = { + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. + val eps = RandomSampler.roundingEpsilon + validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps, + s"Sampling fraction ($fraction) must be on interval [0, 1]", + ctx) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query) + } + + if (ctx.sampleMethod() == null) { + throw new ParseException("TABLESAMPLE does not accept empty inputs.", ctx) + } + + ctx.sampleMethod() match { + case ctx: SampleByRowsContext => + Limit(expression(ctx.expression), query) + + case ctx: SampleByPercentileContext => + val fraction = ctx.percentage.getText.toDouble + val sign = if (ctx.negativeSign == null) 1 else -1 + sample(sign * fraction / 100.0d) + + case ctx: SampleByBytesContext => + val bytesStr = ctx.bytes.getText + if (bytesStr.matches("[0-9]+[bBkKmMgG]")) { + throw new ParseException("TABLESAMPLE(byteLengthLiteral) is not supported", ctx) + } else { + throw new ParseException( + bytesStr + " is not a valid byte length literal, " + + "expected syntax: DIGIT+ ('B' | 'K' | 'M' | 'G')", ctx) + } + + case ctx: SampleByBucketContext if ctx.ON() != null => + if (ctx.identifier != null) { + throw new ParseException( + "TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported", ctx) + } else { + throw new ParseException( + "TABLESAMPLE(BUCKET x OUT OF y ON function) is not supported", ctx) + } + + case ctx: SampleByBucketContext => + sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) + } + } + + /** + * Create a logical plan for a sub-query. 
+ */ + override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.query) + } + + /** + * Create an un-aliased table reference. This is typically used for top-level table references, + * for example: + * {{{ + * INSERT INTO db.tbl2 + * TABLE db.tbl1 + * }}} + */ + override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { + UnresolvedRelation(visitMultipartIdentifier(ctx.multipartIdentifier)) + } + + /** + * Create a table-valued function call with arguments, e.g. range(1000) + */ + override def visitTableValuedFunction(ctx: TableValuedFunctionContext) + : LogicalPlan = withOrigin(ctx) { + val func = ctx.functionTable + val aliases = if (func.tableAlias.identifierList != null) { + visitIdentifierList(func.tableAlias.identifierList) + } else { + Seq.empty + } + + val tvf = UnresolvedTableValuedFunction( + func.funcName.getText, func.expression.asScala.map(expression).toSeq, aliases) + tvf.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan) + } + + /** + * Create an inline table (a virtual table in Hive parlance). + */ + override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { + // Get the backing expressions. + val rows = ctx.expression.asScala.map { e => + expression(e) match { + // inline table comes in two styles: + // style 1: values (1), (2), (3) -- multiple columns are supported + // style 2: values 1, 2, 3 -- only a single column is supported here + case struct: CreateNamedStruct => struct.valExprs // style 1 + case child => Seq(child) // style 2 + } + } + + val aliases = if (ctx.tableAlias.identifierList != null) { + visitIdentifierList(ctx.tableAlias.identifierList) + } else { + Seq.tabulate(rows.head.size)(i => s"col${i + 1}") + } + + val table = UnresolvedInlineTable(aliases, rows.toSeq) + table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a join relation. This is practically the same as + * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. We could add alias names for output columns, for example: + * {{{ + * SELECT a, b, c, d FROM (src1 s1 INNER JOIN src2 s2 ON s1.id = s2.id) dst(a, b, c, d) + * }}} + */ + override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) { + val relation = plan(ctx.relation).optionalMap(ctx.sample)(withSample) + mayApplyAliasPlan(ctx.tableAlias, relation) + } + + /** + * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as + * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. We could add alias names for output columns, for example: + * {{{ + * SELECT col1, col2 FROM testData AS t(col1, col2) + * }}} + */ + override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { + val relation = plan(ctx.query).optionalMap(ctx.sample)(withSample) + if (ctx.tableAlias.strictIdentifier == null) { + // For un-aliased subqueries, use a default alias name that is not likely to conflict with + // normal subquery names, so that parent operators can only access the columns in subquery by + // unqualified names. Users can still use this special qualifier to access columns if they + // know it, but that's not recommended. + SubqueryAlias("__auto_generated_subquery_name", relation) + } else { + mayApplyAliasPlan(ctx.tableAlias, relation) + } + } + + /** + * Create an alias ([[SubqueryAlias]]) for a [[LogicalPlan]]. 
+ */ + private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = { + SubqueryAlias(alias.getText, plan) + } + + /** + * If aliases specified in a FROM clause, create a subquery alias ([[SubqueryAlias]]) and + * column aliases for a [[LogicalPlan]]. + */ + private def mayApplyAliasPlan(tableAlias: TableAliasContext, plan: LogicalPlan): LogicalPlan = { + if (tableAlias.strictIdentifier != null) { + val subquery = SubqueryAlias(tableAlias.strictIdentifier.getText, plan) + if (tableAlias.identifierList != null) { + val columnNames = visitIdentifierList(tableAlias.identifierList) + UnresolvedSubqueryColumnAliases(columnNames, subquery) + } else { + subquery + } + } else { + plan + } + } + + /** + * Create a Sequence of Strings for a parenthesis enclosed alias list. + */ + override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) { + visitIdentifierSeq(ctx.identifierSeq) + } + + /** + * Create a Sequence of Strings for an identifier list. + */ + override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { + ctx.ident.asScala.map(_.getText).toSeq + } + + /* ******************************************************************************************** + * Table Identifier parsing + * ******************************************************************************************** */ + /** + * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern. + */ + override def visitTableIdentifier( + ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) { + TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) + } + + /** + * Create a [[FunctionIdentifier]] from a 'functionName' or 'databaseName'.'functionName' pattern. + */ + override def visitFunctionIdentifier( + ctx: FunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText)) + } + + /** + * Create a multi-part identifier. + */ + override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] = + withOrigin(ctx) { + ctx.parts.asScala.map(_.getText).toSeq + } + + /* ******************************************************************************************** + * Expression parsing + * ******************************************************************************************** */ + /** + * Create an expression from the given context. This method just passes the context on to the + * visitor and only takes care of typing (We assume that the visitor returns an Expression here). + */ + protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx) + + /** + * Create sequence of expressions from the given sequence of contexts. + */ + private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = { + trees.asScala.map(expression).toSeq + } + + /** + * Create a star (i.e. all) expression; this selects all elements (in the specified object). + * Both un-targeted (global) and targeted aliases are supported. + */ + override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) { + UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText).toSeq)) + } + + /** + * Create an aliased expression if an alias is specified. Both single and multi-aliases are + * supported. 
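+ * For example, {{{expr AS a}}} creates an [[Alias]] while {{{expr AS (a, b)}}} creates a
+ * [[MultiAlias]] over the identifier list.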
+ */ + override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.name != null) { + Alias(e, ctx.name.getText)() + } else if (ctx.identifierList != null) { + MultiAlias(e, visitIdentifierList(ctx.identifierList)) + } else { + e + } + } + + /** + * Combine a number of boolean expressions into a balanced expression tree. These expressions are + * either combined by a logical [[And]] or a logical [[Or]]. + * + * A balanced binary tree is created because regular left recursive trees cause considerable + * performance degradations and can cause stack overflows. + */ + override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) { + val expressionType = ctx.operator.getType + val expressionCombiner = expressionType match { + case HoodieSqlBaseParser.AND => And.apply _ + case HoodieSqlBaseParser.OR => Or.apply _ + } + + // Collect all similar left hand contexts. + val contexts = ArrayBuffer(ctx.right) + var current = ctx.left + def collectContexts: Boolean = current match { + case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType => + contexts += lbc.right + current = lbc.left + true + case _ => + contexts += current + false + } + while (collectContexts) { + // No body - all updates take place in the collectContexts. + } + + // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them + // into expressions. + val expressions = contexts.reverseMap(expression) + + // Create a balanced tree. + def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match { + case 0 => + expressions(low) + case 1 => + expressionCombiner(expressions(low), expressions(high)) + case x => + val mid = low + x / 2 + expressionCombiner( + reduceToExpressionTree(low, mid), + reduceToExpressionTree(mid + 1, high)) + } + reduceToExpressionTree(0, expressions.size - 1) + } + + /** + * Invert a boolean expression. + */ + override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) { + Not(expression(ctx.booleanExpression())) + } + + /** + * Create a filtering correlated sub-query (EXISTS). + */ + override def visitExists(ctx: ExistsContext): Expression = { + Exists(plan(ctx.query)) + } + + /** + * Create a comparison expression. This compares two expressions. The following comparison + * operators are supported: + * - Equal: '=' or '==' + * - Null-safe Equal: '<=>' + * - Not Equal: '<>' or '!=' + * - Less than: '<' + * - Less then or Equal: '<=' + * - Greater than: '>' + * - Greater then or Equal: '>=' + */ + override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode] + operator.getSymbol.getType match { + case HoodieSqlBaseParser.EQ => + EqualTo(left, right) + case HoodieSqlBaseParser.NSEQ => + EqualNullSafe(left, right) + case HoodieSqlBaseParser.NEQ | HoodieSqlBaseParser.NEQJ => + Not(EqualTo(left, right)) + case HoodieSqlBaseParser.LT => + LessThan(left, right) + case HoodieSqlBaseParser.LTE => + LessThanOrEqual(left, right) + case HoodieSqlBaseParser.GT => + GreaterThan(left, right) + case HoodieSqlBaseParser.GTE => + GreaterThanOrEqual(left, right) + } + } + + /** + * Create a predicated expression. 
A predicated expression is a normal expression with a + * predicate attached to it, for example: + * {{{ + * a + 1 IS NULL + * }}} + */ + override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) { + val e = expression(ctx.valueExpression) + if (ctx.predicate != null) { + withPredicate(e, ctx.predicate) + } else { + e + } + } + + /** + * Add a predicate to the given expression. Supported expressions are: + * - (NOT) BETWEEN + * - (NOT) IN + * - (NOT) LIKE (ANY | SOME | ALL) + * - (NOT) RLIKE + * - IS (NOT) NULL. + * - IS (NOT) (TRUE | FALSE | UNKNOWN) + * - IS (NOT) DISTINCT FROM + */ + private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) { + // Invert a predicate if it has a valid NOT clause. + def invertIfNotDefined(e: Expression): Expression = ctx.NOT match { + case null => e + case not => Not(e) + } + + def getValueExpressions(e: Expression): Seq[Expression] = e match { + case c: CreateNamedStruct => c.valExprs + case other => Seq(other) + } + + // Create the predicate. + ctx.kind.getType match { + case HoodieSqlBaseParser.BETWEEN => + // BETWEEN is translated to lower <= e && e <= upper + invertIfNotDefined(And( + GreaterThanOrEqual(e, expression(ctx.lower)), + LessThanOrEqual(e, expression(ctx.upper)))) + case HoodieSqlBaseParser.IN if ctx.query != null => + invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query)))) + case HoodieSqlBaseParser.IN => + invertIfNotDefined(In(e, ctx.expression.asScala.map(expression).toSeq)) + case HoodieSqlBaseParser.LIKE => + Option(ctx.quantifier).map(_.getType) match { + case Some(HoodieSqlBaseParser.ANY) | Some(HoodieSqlBaseParser.SOME) => + validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx) + val expressions = expressionList(ctx.expression) + if (expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) { + // If there are many pattern expressions, will throw StackOverflowError. + // So we use LikeAny or NotLikeAny instead. + val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String]) + ctx.NOT match { + case null => LikeAny(e, patterns.toSeq) + case _ => NotLikeAny(e, patterns.toSeq) + } + } else { + ctx.expression.asScala.map(expression) + .map(p => invertIfNotDefined(new Like(e, p))).toSeq.reduceLeft(Or) + } + case Some(HoodieSqlBaseParser.ALL) => + validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx) + val expressions = expressionList(ctx.expression) + if (expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) { + // If there are many pattern expressions, will throw StackOverflowError. + // So we use LikeAll or NotLikeAll instead. + val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String]) + ctx.NOT match { + case null => LikeAll(e, patterns.toSeq) + case _ => NotLikeAll(e, patterns.toSeq) + } + } else { + ctx.expression.asScala.map(expression) + .map(p => invertIfNotDefined(new Like(e, p))).toSeq.reduceLeft(And) + } + case _ => + val escapeChar = Option(ctx.escapeChar).map(string).map { str => + if (str.length != 1) { + throw new ParseException("Invalid escape string." 
+ + "Escape string must contain only one character.", ctx) + } + str.charAt(0) + }.getOrElse('\\') + invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar)) + } + case HoodieSqlBaseParser.RLIKE => + invertIfNotDefined(RLike(e, expression(ctx.pattern))) + case HoodieSqlBaseParser.NULL if ctx.NOT != null => + IsNotNull(e) + case HoodieSqlBaseParser.NULL => + IsNull(e) + case HoodieSqlBaseParser.TRUE => ctx.NOT match { + case null => EqualNullSafe(e, Literal(true)) + case _ => Not(EqualNullSafe(e, Literal(true))) + } + case HoodieSqlBaseParser.FALSE => ctx.NOT match { + case null => EqualNullSafe(e, Literal(false)) + case _ => Not(EqualNullSafe(e, Literal(false))) + } + case HoodieSqlBaseParser.UNKNOWN => ctx.NOT match { + case null => IsUnknown(e) + case _ => IsNotUnknown(e) + } + case HoodieSqlBaseParser.DISTINCT if ctx.NOT != null => + EqualNullSafe(e, expression(ctx.right)) + case HoodieSqlBaseParser.DISTINCT => + Not(EqualNullSafe(e, expression(ctx.right))) + } + } + + /** + * Create a binary arithmetic expression. The following arithmetic operators are supported: + * - Multiplication: '*' + * - Division: '/' + * - Hive Long Division: 'DIV' + * - Modulo: '%' + * - Addition: '+' + * - Subtraction: '-' + * - Binary AND: '&' + * - Binary XOR + * - Binary OR: '|' + */ + override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + ctx.operator.getType match { + case HoodieSqlBaseParser.ASTERISK => + Multiply(left, right) + case HoodieSqlBaseParser.SLASH => + Divide(left, right) + case HoodieSqlBaseParser.PERCENT => + Remainder(left, right) + case HoodieSqlBaseParser.DIV => + IntegralDivide(left, right) + case HoodieSqlBaseParser.PLUS => + Add(left, right) + case HoodieSqlBaseParser.MINUS => + Subtract(left, right) + case HoodieSqlBaseParser.CONCAT_PIPE => + Concat(left :: right :: Nil) + case HoodieSqlBaseParser.AMPERSAND => + BitwiseAnd(left, right) + case HoodieSqlBaseParser.HAT => + BitwiseXor(left, right) + case HoodieSqlBaseParser.PIPE => + BitwiseOr(left, right) + } + } + + /** + * Create a unary arithmetic expression. The following arithmetic operators are supported: + * - Plus: '+' + * - Minus: '-' + * - Bitwise Not: '~' + */ + override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) { + val value = expression(ctx.valueExpression) + ctx.operator.getType match { + case HoodieSqlBaseParser.PLUS => + UnaryPositive(value) + case HoodieSqlBaseParser.MINUS => + UnaryMinus(value) + case HoodieSqlBaseParser.TILDE => + BitwiseNot(value) + } + } + + override def visitCurrentDatetime(ctx: CurrentDatetimeContext): Expression = withOrigin(ctx) { + if (conf.ansiEnabled) { + ctx.name.getType match { + case HoodieSqlBaseParser.CURRENT_DATE => + CurrentDate() + case HoodieSqlBaseParser.CURRENT_TIMESTAMP => + CurrentTimestamp() + } + } else { + // If the parser is not in ansi mode, we should return `UnresolvedAttribute`, in case there + // are columns named `CURRENT_DATE` or `CURRENT_TIMESTAMP`. + UnresolvedAttribute.quoted(ctx.name.getText) + } + } + + /** + * Create a [[Cast]] expression. + */ + override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) { + val rawDataType = typedVisit[DataType](ctx.dataType()) + val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType) + Cast(expression(ctx.expression), dataType) + } + + /** + * Create a [[CreateStruct]] expression. 
+ */ + override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) { + CreateStruct.create(ctx.argument.asScala.map(expression).toSeq) + } + + /** + * Create a [[First]] expression. + */ + override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + First(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression() + } + + /** + * Create a [[Last]] expression. + */ + override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + Last(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression() + } + + /** + * Create a Position expression. + */ + override def visitPosition(ctx: PositionContext): Expression = withOrigin(ctx) { + new StringLocate(expression(ctx.substr), expression(ctx.str)) + } + + /** + * Create a Extract expression. + */ + override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) { + val arguments = Seq(Literal(ctx.field.getText), expression(ctx.source)) + UnresolvedFunction("extract", arguments, isDistinct = false) + } + + /** + * Create a Substring/Substr expression. + */ + override def visitSubstring(ctx: SubstringContext): Expression = withOrigin(ctx) { + if (ctx.len != null) { + Substring(expression(ctx.str), expression(ctx.pos), expression(ctx.len)) + } else { + new Substring(expression(ctx.str), expression(ctx.pos)) + } + } + + /** + * Create a Trim expression. + */ + override def visitTrim(ctx: TrimContext): Expression = withOrigin(ctx) { + val srcStr = expression(ctx.srcStr) + val trimStr = Option(ctx.trimStr).map(expression) + Option(ctx.trimOption).map(_.getType).getOrElse(HoodieSqlBaseParser.BOTH) match { + case HoodieSqlBaseParser.BOTH => + StringTrim(srcStr, trimStr) + case HoodieSqlBaseParser.LEADING => + StringTrimLeft(srcStr, trimStr) + case HoodieSqlBaseParser.TRAILING => + StringTrimRight(srcStr, trimStr) + case other => + throw new ParseException("Function trim doesn't support with " + + s"type $other. Please use BOTH, LEADING or TRAILING as trim type", ctx) + } + } + + /** + * Create a Overlay expression. + */ + override def visitOverlay(ctx: OverlayContext): Expression = withOrigin(ctx) { + val input = expression(ctx.input) + val replace = expression(ctx.replace) + val position = expression(ctx.position) + val lengthOpt = Option(ctx.length).map(expression) + lengthOpt match { + case Some(length) => Overlay(input, replace, position, length) + case None => new Overlay(input, replace, position) + } + } + + /** + * Create a (windowed) Function expression. + */ + override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { + // Create the function call. + val name = ctx.functionName.getText + val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) + // Call `toSeq`, otherwise `ctx.argument.asScala.map(expression)` is `Buffer` in Scala 2.13 + val arguments = ctx.argument.asScala.map(expression).toSeq match { + case Seq(UnresolvedStar(None)) + if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct => + // Transform COUNT(*) into COUNT(1). + Seq(Literal(1)) + case expressions => + expressions + } + val filter = Option(ctx.where).map(expression(_)) + val function = UnresolvedFunction( + getFunctionIdentifier(ctx.functionName), arguments, isDistinct, filter) + + // Check if the function is evaluated in a windowed context. 
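+ // A reference such as rank() OVER w becomes an UnresolvedWindowExpression, while an inline
+ // OVER (PARTITION BY ... ORDER BY ...) spec builds the WindowExpression directly.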
+ ctx.windowSpec match { + case spec: WindowRefContext => + UnresolvedWindowExpression(function, visitWindowRef(spec)) + case spec: WindowDefContext => + WindowExpression(function, visitWindowDef(spec)) + case _ => function + } + } + + + /** + * Create a function database (optional) and name pair, for multipartIdentifier. + * This is used in CREATE FUNCTION, DROP FUNCTION, SHOWFUNCTIONS. + */ + protected def visitFunctionName(ctx: MultipartIdentifierContext): FunctionIdentifier = { + visitFunctionName(ctx, ctx.parts.asScala.map(_.getText).toSeq) + } + + /** + * Create a function database (optional) and name pair. + */ + protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = { + visitFunctionName(ctx, ctx.identifier().asScala.map(_.getText).toSeq) + } + + /** + * Create a function database (optional) and name pair. + */ + private def visitFunctionName(ctx: ParserRuleContext, texts: Seq[String]): FunctionIdentifier = { + texts match { + case Seq(db, fn) => FunctionIdentifier(fn, Option(db)) + case Seq(fn) => FunctionIdentifier(fn, None) + case other => + throw new ParseException(s"Unsupported function name '${texts.mkString(".")}'", ctx) + } + } + + /** + * Get a function identifier consist by database (optional) and name. + */ + protected def getFunctionIdentifier(ctx: FunctionNameContext): FunctionIdentifier = { + if (ctx.qualifiedName != null) { + visitFunctionName(ctx.qualifiedName) + } else { + FunctionIdentifier(ctx.getText, None) + } + } + + /** + * Create an [[LambdaFunction]]. + */ + override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) { + val arguments = ctx.identifier().asScala.map { name => + UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts) + } + val function = expression(ctx.expression).transformUp { + case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts) + } + LambdaFunction(function, arguments.toSeq) + } + + /** + * Create a reference to a window frame, i.e. [[WindowSpecReference]]. + */ + override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) { + WindowSpecReference(ctx.name.getText) + } + + /** + * Create a window definition, i.e. [[WindowSpecDefinition]]. + */ + override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) { + // CLUSTER BY ... | PARTITION BY ... ORDER BY ... + val partition = ctx.partition.asScala.map(expression) + val order = ctx.sortItem.asScala.map(visitSortItem) + + // RANGE/ROWS BETWEEN ... + val frameSpecOption = Option(ctx.windowFrame).map { frame => + val frameType = frame.frameType.getType match { + case HoodieSqlBaseParser.RANGE => RangeFrame + case HoodieSqlBaseParser.ROWS => RowFrame + } + + SpecifiedWindowFrame( + frameType, + visitFrameBound(frame.start), + Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow)) + } + + WindowSpecDefinition( + partition.toSeq, + order.toSeq, + frameSpecOption.getOrElse(UnspecifiedFrame)) + } + + /** + * Create or resolve a frame boundary expressions. 
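+ * For example, in {{{ROWS BETWEEN 2 PRECEDING AND CURRENT ROW}}} the lower bound is the
+ * negated literal and the upper bound is [[CurrentRow]]; UNBOUNDED maps to
+ * [[UnboundedPreceding]] or [[UnboundedFollowing]].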
+ */ + override def visitFrameBound(ctx: FrameBoundContext): Expression = withOrigin(ctx) { + def value: Expression = { + val e = expression(ctx.expression) + validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx) + e + } + + ctx.boundType.getType match { + case HoodieSqlBaseParser.PRECEDING if ctx.UNBOUNDED != null => + UnboundedPreceding + case HoodieSqlBaseParser.PRECEDING => + UnaryMinus(value) + case HoodieSqlBaseParser.CURRENT => + CurrentRow + case HoodieSqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null => + UnboundedFollowing + case HoodieSqlBaseParser.FOLLOWING => + value + } + } + + /** + * Create a [[CreateStruct]] expression. + */ + override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { + CreateStruct(ctx.namedExpression().asScala.map(expression).toSeq) + } + + /** + * Create a [[ScalarSubquery]] expression. + */ + override def visitSubqueryExpression( + ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) { + ScalarSubquery(plan(ctx.query)) + } + + /** + * Create a value based [[CaseWhen]] expression. This has the following SQL form: + * {{{ + * CASE [expression] + * WHEN [value] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + */ + override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { + val e = expression(ctx.value) + val branches = ctx.whenClause.asScala.map { wCtx => + (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) + } + CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax: + * {{{ + * CASE + * WHEN [predicate] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + * + * @param ctx the parse tree + * */ + override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) { + val branches = ctx.whenClause.asScala.map { wCtx => + (expression(wCtx.condition), expression(wCtx.result)) + } + CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression)) + } + + /** + * Currently only regex in expressions of SELECT statements are supported; in other + * places, e.g., where `(a)?+.+` = 2, regex are not meaningful. + */ + private def canApplyRegex(ctx: ParserRuleContext): Boolean = withOrigin(ctx) { + var parent = ctx.getParent + var rtn = false + while (parent != null) { + if (parent.isInstanceOf[NamedExpressionContext]) { + rtn = true + } + parent = parent.getParent + } + rtn + } + + /** + * Create a dereference expression. The return type depends on the type of the parent. + * If the parent is an [[UnresolvedAttribute]], it can be a [[UnresolvedAttribute]] or + * a [[UnresolvedRegex]] for regex quoted in ``; if the parent is some other expression, + * it can be [[UnresolvedExtractValue]]. 
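+ * For example, {{{a.b}}} extends the attribute to UnresolvedAttribute(Seq("a", "b")), while a
+ * field access on any other expression, such as {{{arr[0].field}}}, becomes an
+ * [[UnresolvedExtractValue]].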
+ */ + override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) { + val attr = ctx.fieldName.getText + expression(ctx.base) match { + case unresolved_attr @ UnresolvedAttribute(nameParts) => + ctx.fieldName.getStart.getText match { + case escapedIdentifier(columnNameRegex) + if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) => + UnresolvedRegex(columnNameRegex, Some(unresolved_attr.name), + conf.caseSensitiveAnalysis) + case _ => + UnresolvedAttribute(nameParts :+ attr) + } + case e => + UnresolvedExtractValue(e, Literal(attr)) + } + } + + /** + * Create an [[UnresolvedAttribute]] expression or a [[UnresolvedRegex]] if it is a regex + * quoted in `` + */ + override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) { + ctx.getStart.getText match { + case escapedIdentifier(columnNameRegex) + if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) => + UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis) + case _ => + UnresolvedAttribute.quoted(ctx.getText) + } + + } + + /** + * Create an [[UnresolvedExtractValue]] expression, this is used for subscript access to an array. + */ + override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) { + UnresolvedExtractValue(expression(ctx.value), expression(ctx.index)) + } + + /** + * Create an expression for an expression between parentheses. This is need because the ANTLR + * visitor cannot automatically convert the nested context into an expression. + */ + override def visitParenthesizedExpression( + ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) { + expression(ctx.expression) + } + + /** + * Create a [[SortOrder]] expression. + */ + override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) { + val direction = if (ctx.DESC != null) { + Descending + } else { + Ascending + } + val nullOrdering = if (ctx.FIRST != null) { + NullsFirst + } else if (ctx.LAST != null) { + NullsLast + } else { + direction.defaultNullOrdering + } + SortOrder(expression(ctx.expression), direction, nullOrdering, Seq.empty) + } + + /** + * Create a typed Literal expression. A typed literal has the following SQL syntax: + * {{{ + * [TYPE] '[VALUE]' + * }}} + * Currently Date, Timestamp, Interval and Binary typed literals are supported. 
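+ * For example: {{{DATE '2021-01-01'}}}, {{{TIMESTAMP '2021-01-01 12:00:00'}}},
+ * {{{INTERVAL '1 day'}}} and {{{X'1C'}}}.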
+ */ + override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { + val value = string(ctx.STRING) + val valueType = ctx.identifier.getText.toUpperCase(Locale.ROOT) + + def toLiteral[T](f: UTF8String => Option[T], t: DataType): Literal = { + f(UTF8String.fromString(value)).map(Literal(_, t)).getOrElse { + throw new ParseException(s"Cannot parse the $valueType value: $value", ctx) + } + } + try { + valueType match { + case "DATE" => + toLiteral(stringToDate(_, getZoneId(SQLConf.get.sessionLocalTimeZone)), DateType) + case "TIMESTAMP" => + val zoneId = getZoneId(SQLConf.get.sessionLocalTimeZone) + toLiteral(stringToTimestamp(_, zoneId), TimestampType) + case "INTERVAL" => + val interval = try { + IntervalUtils.stringToInterval(UTF8String.fromString(value)) + } catch { + case e: IllegalArgumentException => + val ex = new ParseException("Cannot parse the INTERVAL value: " + value, ctx) + ex.setStackTrace(e.getStackTrace) + throw ex + } + Literal(interval, CalendarIntervalType) + case "X" => + val padding = if (value.length % 2 != 0) "0" else "" + Literal(DatatypeConverter.parseHexBinary(padding + value)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not" + + " supported.", ctx) + } + } catch { + case e: IllegalArgumentException => + val message = Option(e.getMessage).getOrElse(s"Exception parsing $valueType") + throw new ParseException(message, ctx) + } + } + + /** + * Create a NULL literal expression. + */ + override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) { + Literal(null) + } + + /** + * Create a Boolean literal expression. + */ + override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) { + if (ctx.getText.toBoolean) { + Literal.TrueLiteral + } else { + Literal.FalseLiteral + } + } + + /** + * Create an integral literal expression. The code selects the most narrow integral type + * possible, either a BigDecimal, a Long or an Integer is returned. + */ + override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) { + BigDecimal(ctx.getText) match { + case v if v.isValidInt => + Literal(v.intValue) + case v if v.isValidLong => + Literal(v.longValue) + case v => Literal(v.underlying()) + } + } + + /** + * Create a decimal literal for a regular decimal number. + */ + override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(BigDecimal(ctx.getText).underlying()) + } + + /** + * Create a decimal literal for a regular decimal number or a scientific decimal number. + */ + override def visitLegacyDecimalLiteral( + ctx: LegacyDecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(BigDecimal(ctx.getText).underlying()) + } + + /** + * Create a double literal for number with an exponent, e.g. 1E-30 + */ + override def visitExponentLiteral(ctx: ExponentLiteralContext): Literal = { + numericLiteral(ctx, ctx.getText, /* exponent values don't have a suffix */ + Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble) + } + + /** Create a numeric literal expression. 
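+ * The raw text (with any type suffix stripped by the caller) is range-checked against
+ * [minValue, maxValue] before conversion, so out-of-range literals fail at parse time.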
*/ + private def numericLiteral( + ctx: NumberContext, + rawStrippedQualifier: String, + minValue: BigDecimal, + maxValue: BigDecimal, + typeName: String)(converter: String => Any): Literal = withOrigin(ctx) { + try { + val rawBigDecimal = BigDecimal(rawStrippedQualifier) + if (rawBigDecimal < minValue || rawBigDecimal > maxValue) { + throw new ParseException(s"Numeric literal ${rawStrippedQualifier} does not " + + s"fit in range [${minValue}, ${maxValue}] for type ${typeName}", ctx) + } + Literal(converter(rawStrippedQualifier)) + } catch { + case e: NumberFormatException => + throw new ParseException(e.getMessage, ctx) + } + } + + /** + * Create a Byte Literal expression. + */ + override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Byte.MinValue, Byte.MaxValue, ByteType.simpleString)(_.toByte) + } + + /** + * Create a Short Literal expression. + */ + override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Short.MinValue, Short.MaxValue, ShortType.simpleString)(_.toShort) + } + + /** + * Create a Long Literal expression. + */ + override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Long.MinValue, Long.MaxValue, LongType.simpleString)(_.toLong) + } + + /** + * Create a Float Literal expression. + */ + override def visitFloatLiteral(ctx: FloatLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Float.MinValue, Float.MaxValue, FloatType.simpleString)(_.toFloat) + } + + /** + * Create a Double Literal expression. + */ + override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble) + } + + /** + * Create a BigDecimal Literal expression. + */ + override def visitBigDecimalLiteral(ctx: BigDecimalLiteralContext): Literal = { + val raw = ctx.getText.substring(0, ctx.getText.length - 2) + try { + Literal(BigDecimal(raw).underlying()) + } catch { + case e: AnalysisException => + throw new ParseException(e.message, ctx) + } + } + + /** + * Create a String literal expression. + */ + override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { + Literal(createString(ctx)) + } + + /** + * Create a String from a string literal context. This supports multiple consecutive string + * literals, these are concatenated, for example this expression "'hello' 'world'" will be + * converted into "helloworld". + * + * Special characters can be escaped by using Hive/C-style escaping. + */ + private def createString(ctx: StringLiteralContext): String = { + if (conf.escapedStringLiterals) { + ctx.STRING().asScala.map(stringWithoutUnescape).mkString + } else { + ctx.STRING().asScala.map(string).mkString + } + } + + /** + * Create a [[CalendarInterval]] literal expression. Two syntaxes are supported: + * - multiple unit value pairs, for instance: interval 2 months 2 days. + * - from-to unit, for instance: interval '1-2' year to month. 
+ */ + override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) { + Literal(parseIntervalLiteral(ctx), CalendarIntervalType) + } + + /** + * Create a [[CalendarInterval]] object + */ + protected def parseIntervalLiteral(ctx: IntervalContext): CalendarInterval = withOrigin(ctx) { + if (ctx.errorCapturingMultiUnitsInterval != null) { + val innerCtx = ctx.errorCapturingMultiUnitsInterval + if (innerCtx.unitToUnitInterval != null) { + throw new ParseException( + "Can only have a single from-to unit in the interval literal syntax", + innerCtx.unitToUnitInterval) + } + visitMultiUnitsInterval(innerCtx.multiUnitsInterval) + } else if (ctx.errorCapturingUnitToUnitInterval != null) { + val innerCtx = ctx.errorCapturingUnitToUnitInterval + if (innerCtx.error1 != null || innerCtx.error2 != null) { + val errorCtx = if (innerCtx.error1 != null) innerCtx.error1 else innerCtx.error2 + throw new ParseException( + "Can only have a single from-to unit in the interval literal syntax", + errorCtx) + } + visitUnitToUnitInterval(innerCtx.body) + } else { + throw new ParseException("at least one time unit should be given for interval literal", ctx) + } + } + + /** + * Creates a [[CalendarInterval]] with multiple unit value pairs, e.g. 1 YEAR 2 DAYS. + */ + override def visitMultiUnitsInterval(ctx: MultiUnitsIntervalContext): CalendarInterval = { + withOrigin(ctx) { + val units = ctx.unit.asScala + val values = ctx.intervalValue().asScala + try { + assert(units.length == values.length) + val kvs = units.indices.map { i => + val u = units(i).getText + val v = if (values(i).STRING() != null) { + val value = string(values(i).STRING()) + // SPARK-32840: For invalid cases, e.g. INTERVAL '1 day 2' hour, + // INTERVAL 'interval 1' day, we need to check ahead before they are concatenated with + // units and become valid ones, e.g. '1 day 2 hour'. + // Ideally, we only ensure the value parts don't contain any units here. + if (value.exists(Character.isLetter)) { + throw new ParseException("Can only use numbers in the interval value part for" + + s" multiple unit value pairs interval form, but got invalid value: $value", ctx) + } + value + } else { + values(i).getText + } + UTF8String.fromString(" " + v + " " + u) + } + IntervalUtils.stringToInterval(UTF8String.concat(kvs: _*)) + } catch { + case i: IllegalArgumentException => + val e = new ParseException(i.getMessage, ctx) + e.setStackTrace(i.getStackTrace) + throw e + } + } + } + + /** + * Creates a [[CalendarInterval]] with from-to unit, e.g. '2-1' YEAR TO MONTH. 
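+ * Only the year-month and day/hour/minute/second combinations handled below are accepted; any other FROM ... TO ... pair raises a ParseException.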
+ */ + override def visitUnitToUnitInterval(ctx: UnitToUnitIntervalContext): CalendarInterval = { + withOrigin(ctx) { + val value = Option(ctx.intervalValue.STRING).map(string).getOrElse { + throw new ParseException("The value of from-to unit must be a string", ctx.intervalValue) + } + try { + val from = ctx.from.getText.toLowerCase(Locale.ROOT) + val to = ctx.to.getText.toLowerCase(Locale.ROOT) + (from, to) match { + case ("year", "month") => + IntervalUtils.fromYearMonthString(value) + case ("day", "hour") => + IntervalUtils.fromDayTimeString(value, IntervalUnit.DAY, IntervalUnit.HOUR) + case ("day", "minute") => + IntervalUtils.fromDayTimeString(value, IntervalUnit.DAY, IntervalUnit.MINUTE) + case ("day", "second") => + IntervalUtils.fromDayTimeString(value, IntervalUnit.DAY, IntervalUnit.SECOND) + case ("hour", "minute") => + IntervalUtils.fromDayTimeString(value, IntervalUnit.HOUR, IntervalUnit.MINUTE) + case ("hour", "second") => + IntervalUtils.fromDayTimeString(value, IntervalUnit.HOUR, IntervalUnit.SECOND) + case ("minute", "second") => + IntervalUtils.fromDayTimeString(value, IntervalUnit.MINUTE, IntervalUnit.SECOND) + case _ => + throw new ParseException(s"Intervals FROM $from TO $to are not supported.", ctx) + } + } catch { + // Handle Exceptions thrown by CalendarInterval + case e: IllegalArgumentException => + val pe = new ParseException(e.getMessage, ctx) + pe.setStackTrace(e.getStackTrace) + throw pe + } + } + } + + /* ******************************************************************************************** + * DataType parsing + * ******************************************************************************************** */ + + /** + * Resolve/create a primitive type. + */ + override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { + val dataType = ctx.identifier.getText.toLowerCase(Locale.ROOT) + (dataType, ctx.INTEGER_VALUE().asScala.toList) match { + case ("boolean", Nil) => BooleanType + case ("tinyint" | "byte", Nil) => ByteType + case ("smallint" | "short", Nil) => ShortType + case ("int" | "integer", Nil) => IntegerType + case ("bigint" | "long", Nil) => LongType + case ("float" | "real", Nil) => FloatType + case ("double", Nil) => DoubleType + case ("date", Nil) => DateType + case ("timestamp", Nil) => TimestampType + case ("string", Nil) => StringType + case ("character" | "char", length :: Nil) => CharType(length.getText.toInt) + case ("varchar", length :: Nil) => VarcharType(length.getText.toInt) + case ("binary", Nil) => BinaryType + case ("decimal" | "dec" | "numeric", Nil) => DecimalType.USER_DEFAULT + case ("decimal" | "dec" | "numeric", precision :: Nil) => + DecimalType(precision.getText.toInt, 0) + case ("decimal" | "dec" | "numeric", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case ("void", Nil) => NullType + case ("interval", Nil) => CalendarIntervalType + case (dt, params) => + val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt + throw new ParseException(s"DataType $dtStr is not supported.", ctx) + } + } + + /** + * Create a complex DataType. Arrays, Maps and Structures are supported. 
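+ * For example, `ARRAY<INT>`, `MAP<STRING, DOUBLE>` and `STRUCT<a: INT, b: STRING>` all resolve through this visitor.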
+ */ + override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { + ctx.complex.getType match { + case HoodieSqlBaseParser.ARRAY => + ArrayType(typedVisit(ctx.dataType(0))) + case HoodieSqlBaseParser.MAP => + MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) + case HoodieSqlBaseParser.STRUCT => + StructType(Option(ctx.complexColTypeList).toSeq.flatMap(visitComplexColTypeList)) + } + } + + /** + * Create top level table schema. + */ + protected def createSchema(ctx: ColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.colType().asScala.map(visitColType).toSeq + } + + /** + * Create a top level [[StructField]] from a column definition. + */ + override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + + val builder = new MetadataBuilder + // Add comment to metadata + Option(commentSpec()).map(visitCommentSpec).foreach { + builder.putString("comment", _) + } + + StructField( + name = colName.getText, + dataType = typedVisit[DataType](ctx.dataType), + nullable = NULL == null, + metadata = builder.build()) + } + + /** + * Create a [[StructType]] from a sequence of [[StructField]]s. + */ + protected def createStructType(ctx: ComplexColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitComplexColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitComplexColTypeList( + ctx: ComplexColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.complexColType().asScala.map(visitComplexColType).toSeq + } + + /** + * Create a [[StructField]] from a column definition. + */ + override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + val structField = StructField( + name = identifier.getText, + dataType = typedVisit(dataType()), + nullable = NULL == null) + Option(commentSpec).map(visitCommentSpec).map(structField.withComment).getOrElse(structField) + } + + /** + * Create a location string. + */ + override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) { + string(ctx.STRING) + } + + /** + * Create an optional location string. + */ + protected def visitLocationSpecList(ctx: java.util.List[LocationSpecContext]): Option[String] = { + ctx.asScala.headOption.map(visitLocationSpec) + } + + /** + * Create a comment string. + */ + override def visitCommentSpec(ctx: CommentSpecContext): String = withOrigin(ctx) { + string(ctx.STRING) + } + + /** + * Create an optional comment string. + */ + protected def visitCommentSpecList(ctx: java.util.List[CommentSpecContext]): Option[String] = { + ctx.asScala.headOption.map(visitCommentSpec) + } + + /** + * Create a [[BucketSpec]]. 
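+ * For example, `CLUSTERED BY (id) SORTED BY (ts ASC) INTO 8 BUCKETS`; only ascending sort order is allowed here.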
+ */ + override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) { + BucketSpec( + ctx.INTEGER_VALUE.getText.toInt, + visitIdentifierList(ctx.identifierList), + Option(ctx.orderedIdentifierList) + .toSeq + .flatMap(_.orderedIdentifier.asScala) + .map { orderedIdCtx => + Option(orderedIdCtx.ordering).map(_.getText).foreach { dir => + if (dir.toLowerCase(Locale.ROOT) != "asc") { + operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx) + } + } + + orderedIdCtx.ident.getText + }) + } + + /** + * Convert a table property list into a key-value map. + * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. + */ + override def visitTablePropertyList( + ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { + val properties = ctx.tableProperty.asScala.map { property => + val key = visitTablePropertyKey(property.key) + val value = visitTablePropertyValue(property.value) + key -> value + } + // Check for duplicate property names. + checkDuplicateKeys(properties.toSeq, ctx) + properties.toMap + } + + /** + * Parse a key-value map from a [[TablePropertyListContext]], assuming all values are specified. + */ + def visitPropertyKeyValues(ctx: TablePropertyListContext): Map[String, String] = { + val props = visitTablePropertyList(ctx) + val badKeys = props.collect { case (key, null) => key } + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props + } + + /** + * Parse a list of keys from a [[TablePropertyListContext]], assuming no values are specified. + */ + def visitPropertyKeys(ctx: TablePropertyListContext): Seq[String] = { + val props = visitTablePropertyList(ctx) + val badKeys = props.filter { case (_, v) => v != null }.keys + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props.keys.toSeq + } + + /** + * A table property key can either be String or a collection of dot separated elements. This + * function extracts the property key based on whether its a string literal or a table property + * identifier. + */ + override def visitTablePropertyKey(key: TablePropertyKeyContext): String = { + if (key.STRING != null) { + string(key.STRING) + } else { + key.getText + } + } + + /** + * A table property value can be String, Integer, Boolean or Decimal. This function extracts + * the property value based on whether its a string, integer, boolean or decimal literal. + */ + override def visitTablePropertyValue(value: TablePropertyValueContext): String = { + if (value == null) { + null + } else if (value.STRING != null) { + string(value.STRING) + } else if (value.booleanValue != null) { + value.getText.toLowerCase(Locale.ROOT) + } else { + value.getText + } + } + + /** + * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). + */ + type TableHeader = (Seq[String], Boolean, Boolean, Boolean) + + /** + * Type to keep track of table clauses: + * - partition transforms + * - partition columns + * - bucketSpec + * - properties + * - options + * - location + * - comment + * - serde + * + * Note: Partition transforms are based on existing table schema definition. It can be simple + * column names, or functions like `year(date_col)`. Partition columns are column names with data + * types like `i INT`, which should be appended to the existing table schema. 
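+ * For example, `PARTITIONED BY (dt STRING)` contributes a typed partition column, while `PARTITIONED BY (days(ts))` contributes a partition transform over an existing column.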
+ */ + type TableClauses = ( + Seq[Transform], Seq[StructField], Option[BucketSpec], Map[String, String], + Map[String, String], Option[String], Option[String], Option[SerdeInfo]) + + /** + * Validate a create table statement and return the [[TableIdentifier]]. + */ + override def visitCreateTableHeader( + ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { + val temporary = ctx.TEMPORARY != null + val ifNotExists = ctx.EXISTS != null + if (temporary && ifNotExists) { + operationNotAllowed("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx) + } + val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq + (multipartIdentifier, temporary, ifNotExists, ctx.EXTERNAL != null) + } + + /** + * Validate a replace table statement and return the [[TableIdentifier]]. + */ + override def visitReplaceTableHeader( + ctx: ReplaceTableHeaderContext): TableHeader = withOrigin(ctx) { + val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq + (multipartIdentifier, false, false, false) + } + + /** + * Parse a qualified name to a multipart name. + */ + override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) { + ctx.identifier.asScala.map(_.getText).toSeq + } + + /** + * Parse a list of transforms or columns. + */ + override def visitPartitionFieldList( + ctx: PartitionFieldListContext): (Seq[Transform], Seq[StructField]) = withOrigin(ctx) { + val (transforms, columns) = ctx.fields.asScala.map { + case transform: PartitionTransformContext => + (Some(visitPartitionTransform(transform)), None) + case field: PartitionColumnContext => + (None, Some(visitColType(field.colType))) + }.unzip + + (transforms.flatten.toSeq, columns.flatten.toSeq) + } + + override def visitPartitionTransform( + ctx: PartitionTransformContext): Transform = withOrigin(ctx) { + def getFieldReference( + ctx: ApplyTransformContext, + arg: V2Expression): FieldReference = { + lazy val name: String = ctx.identifier.getText + arg match { + case ref: FieldReference => + ref + case nonRef => + throw new ParseException( + s"Expected a column reference for transform $name: ${nonRef.describe}", ctx) + } + } + + def getSingleFieldReference( + ctx: ApplyTransformContext, + arguments: Seq[V2Expression]): FieldReference = { + lazy val name: String = ctx.identifier.getText + if (arguments.size > 1) { + throw new ParseException(s"Too many arguments for transform $name", ctx) + } else if (arguments.isEmpty) { + throw new ParseException(s"Not enough arguments for transform $name", ctx) + } else { + getFieldReference(ctx, arguments.head) + } + } + + ctx.transform match { + case identityCtx: IdentityTransformContext => + IdentityTransform(FieldReference(typedVisit[Seq[String]](identityCtx.qualifiedName))) + + case applyCtx: ApplyTransformContext => + val arguments = applyCtx.argument.asScala.map(visitTransformArgument).toSeq + + applyCtx.identifier.getText match { + case "bucket" => + val numBuckets: Int = arguments.head match { + case LiteralValue(shortValue, ShortType) => + shortValue.asInstanceOf[Short].toInt + case LiteralValue(intValue, IntegerType) => + intValue.asInstanceOf[Int] + case LiteralValue(longValue, LongType) => + longValue.asInstanceOf[Long].toInt + case lit => + throw new ParseException(s"Invalid number of buckets: ${lit.describe}", applyCtx) + } + + val fields = arguments.tail.map(arg => getFieldReference(applyCtx, arg)) + + BucketTransform(LiteralValue(numBuckets, IntegerType), fields.toSeq) + + case "years" => + 
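+ // The years/months/days/hours transforms each take exactly one column reference.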
YearsTransform(getSingleFieldReference(applyCtx, arguments)) + + case "months" => + MonthsTransform(getSingleFieldReference(applyCtx, arguments)) + + case "days" => + DaysTransform(getSingleFieldReference(applyCtx, arguments)) + + case "hours" => + HoursTransform(getSingleFieldReference(applyCtx, arguments)) + + case name => + ApplyTransform(name, arguments) + } + } + } + + /** + * Parse an argument to a transform. An argument may be a field reference (qualified name) or + * a value literal. + */ + override def visitTransformArgument(ctx: TransformArgumentContext): V2Expression = { + withOrigin(ctx) { + val reference = Option(ctx.qualifiedName) + .map(typedVisit[Seq[String]]) + .map(FieldReference(_)) + val literal = Option(ctx.constant) + .map(typedVisit[Literal]) + .map(lit => LiteralValue(lit.value, lit.dataType)) + reference.orElse(literal) + .getOrElse(throw new ParseException(s"Invalid transform argument", ctx)) + } + } + + def cleanTableProperties( + ctx: ParserRuleContext, properties: Map[String, String]): Map[String, String] = { + import TableCatalog._ + val legacyOn = conf.getConf(SQLConf.LEGACY_PROPERTY_NON_RESERVED) + properties.filter { + case (PROP_PROVIDER, _) if !legacyOn => + throw new ParseException(s"$PROP_PROVIDER is a reserved table property, please use" + + s" the USING clause to specify it.", ctx) + case (PROP_PROVIDER, _) => false + case (PROP_LOCATION, _) if !legacyOn => + throw new ParseException(s"$PROP_LOCATION is a reserved table property, please use" + + s" the LOCATION clause to specify it.", ctx) + case (PROP_LOCATION, _) => false + case (PROP_OWNER, _) if !legacyOn => + throw new ParseException(s"$PROP_OWNER is a reserved table property, it will be" + + s" set to the current user", ctx) + case (PROP_OWNER, _) => false + case _ => true + } + } + + def cleanTableOptions( + ctx: ParserRuleContext, + options: Map[String, String], + location: Option[String]): (Map[String, String], Option[String]) = { + var path = location + val filtered = cleanTableProperties(ctx, options).filter { + case (k, v) if k.equalsIgnoreCase("path") && path.nonEmpty => + throw new ParseException(s"Duplicated table paths found: '${path.get}' and '$v'. LOCATION" + + s" and the case insensitive key 'path' in OPTIONS are all used to indicate the custom" + + s" table path, you can only specify one of them.", ctx) + case (k, v) if k.equalsIgnoreCase("path") => + path = Some(v) + false + case _ => true + } + (filtered, path) + } + + /** + * Create a [[SerdeInfo]] for creating tables. + * + * Format: STORED AS (name | INPUTFORMAT input_format OUTPUTFORMAT output_format) + */ + override def visitCreateFileFormat(ctx: CreateFileFormatContext): SerdeInfo = withOrigin(ctx) { + (ctx.fileFormat, ctx.storageHandler) match { + // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format + case (c: TableFileFormatContext, null) => + SerdeInfo(formatClasses = Some(FormatClasses(string(c.inFmt), string(c.outFmt)))) + // Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO + case (c: GenericFileFormatContext, null) => + SerdeInfo(storedAs = Some(c.identifier.getText)) + case (null, storageHandler) => + operationNotAllowed("STORED BY", ctx) + case _ => + throw new ParseException("Expected either STORED AS or STORED BY, not both", ctx) + } + } + + /** + * Create a [[SerdeInfo]] used for creating tables. 
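+ * Dispatches to [[visitRowFormatSerde]] or [[visitRowFormatDelimited]] depending on which clause form is present.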
+ * + * Example format: + * {{{ + * SERDE serde_name [WITH SERDEPROPERTIES (k1=v1, k2=v2, ...)] + * }}} + * + * OR + * + * {{{ + * DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]] + * [COLLECTION ITEMS TERMINATED BY char] + * [MAP KEYS TERMINATED BY char] + * [LINES TERMINATED BY char] + * [NULL DEFINED AS char] + * }}} + */ + def visitRowFormat(ctx: RowFormatContext): SerdeInfo = withOrigin(ctx) { + ctx match { + case serde: RowFormatSerdeContext => visitRowFormatSerde(serde) + case delimited: RowFormatDelimitedContext => visitRowFormatDelimited(delimited) + } + } + + /** + * Create SERDE row format name and properties pair. + */ + override def visitRowFormatSerde(ctx: RowFormatSerdeContext): SerdeInfo = withOrigin(ctx) { + import ctx._ + SerdeInfo( + serde = Some(string(name)), + serdeProperties = Option(tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)) + } + + /** + * Create a delimited row format properties object. + */ + override def visitRowFormatDelimited( + ctx: RowFormatDelimitedContext): SerdeInfo = withOrigin(ctx) { + // Collect the entries if any. + def entry(key: String, value: Token): Seq[(String, String)] = { + Option(value).toSeq.map(x => key -> string(x)) + } + // TODO we need proper support for the NULL format. + val entries = + entry("field.delim", ctx.fieldsTerminatedBy) ++ + entry("serialization.format", ctx.fieldsTerminatedBy) ++ + entry("escape.delim", ctx.escapedBy) ++ + // The following typo is inherited from Hive... + entry("colelction.delim", ctx.collectionItemsTerminatedBy) ++ + entry("mapkey.delim", ctx.keysTerminatedBy) ++ + Option(ctx.linesSeparatedBy).toSeq.map { token => + val value = string(token) + validate( + value == "\n", + s"LINES TERMINATED BY only supports newline '\\n' right now: $value", + ctx) + "line.delim" -> value + } + SerdeInfo(serdeProperties = entries.toMap) + } + + /** + * Throw a [[ParseException]] if the user specified incompatible SerDes through ROW FORMAT + * and STORED AS. + * + * The following are allowed. Anything else is not: + * ROW FORMAT SERDE ... STORED AS [SEQUENCEFILE | RCFILE | TEXTFILE] + * ROW FORMAT DELIMITED ... STORED AS TEXTFILE + * ROW FORMAT ... STORED AS INPUTFORMAT ... OUTPUTFORMAT ... 
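+ * For example, `ROW FORMAT DELIMITED ... STORED AS TEXTFILE` is accepted, while `ROW FORMAT DELIMITED ... STORED AS PARQUET` is rejected.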
+ */ + protected def validateRowFormatFileFormat( + rowFormatCtx: RowFormatContext, + createFileFormatCtx: CreateFileFormatContext, + parentCtx: ParserRuleContext): Unit = { + if (!(rowFormatCtx == null || createFileFormatCtx == null)) { + (rowFormatCtx, createFileFormatCtx.fileFormat) match { + case (_, ffTable: TableFileFormatContext) => // OK + case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) => + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { + case ("sequencefile" | "textfile" | "rcfile") => // OK + case fmt => + operationNotAllowed( + s"ROW FORMAT SERDE is incompatible with format '$fmt', which also specifies a serde", + parentCtx) + } + case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) => + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { + case "textfile" => // OK + case fmt => operationNotAllowed( + s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx) + } + case _ => + // should never happen + def str(ctx: ParserRuleContext): String = { + (0 until ctx.getChildCount).map { i => ctx.getChild(i).getText }.mkString(" ") + } + + operationNotAllowed( + s"Unexpected combination of ${str(rowFormatCtx)} and ${str(createFileFormatCtx)}", + parentCtx) + } + } + } + + protected def validateRowFormatFileFormat( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) { + validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx) + } + } + + override def visitCreateTableClauses(ctx: CreateTableClausesContext): TableClauses = { + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx) + checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) + checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + + if (ctx.skewSpec.size > 0) { + operationNotAllowed("CREATE TABLE ... 
SKEWED BY", ctx) + } + + val (partTransforms, partCols) = + Option(ctx.partitioning).map(visitPartitionFieldList).getOrElse((Nil, Nil)) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) + val cleanedProperties = cleanTableProperties(ctx, properties) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) + val location = visitLocationSpecList(ctx.locationSpec()) + val (cleanedOptions, newLocation) = cleanTableOptions(ctx, options, location) + val comment = visitCommentSpecList(ctx.commentSpec()) + val serdeInfo = + getSerdeInfo(ctx.rowFormat.asScala.toSeq, ctx.createFileFormat.asScala.toSeq, ctx) + (partTransforms, partCols, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment, + serdeInfo) + } + + protected def getSerdeInfo( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + ctx: ParserRuleContext, + skipCheck: Boolean = false): Option[SerdeInfo] = { + if (!skipCheck) validateRowFormatFileFormat(rowFormatCtx, createFileFormatCtx, ctx) + val rowFormatSerdeInfo = rowFormatCtx.map(visitRowFormat) + val fileFormatSerdeInfo = createFileFormatCtx.map(visitCreateFileFormat) + (fileFormatSerdeInfo ++ rowFormatSerdeInfo).reduceLeftOption((l, r) => l.merge(r)) + } + + private def partitionExpressions( + partTransforms: Seq[Transform], + partCols: Seq[StructField], + ctx: ParserRuleContext): Seq[Transform] = { + if (partTransforms.nonEmpty) { + if (partCols.nonEmpty) { + val references = partTransforms.map(_.describe()).mkString(", ") + val columns = partCols + .map(field => s"${field.name} ${field.dataType.simpleString}") + .mkString(", ") + operationNotAllowed( + s"""PARTITION BY: Cannot mix partition expressions and partition columns: + |Expressions: $references + |Columns: $columns""".stripMargin, ctx) + + } + partTransforms + } else { + // columns were added to create the schema. convert to column references + partCols.map { column => + IdentityTransform(FieldReference(Seq(column.name))) + } + } + } + + /** + * Create a table, returning a [[CreateTableStatement]] logical plan. + * + * Expected format: + * {{{ + * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name + * [USING table_provider] + * create_table_clauses + * [[AS] select_statement]; + * + * create_table_clauses (order insensitive): + * [PARTITIONED BY (partition_fields)] + * [OPTIONS table_property_list] + * [ROW FORMAT row_format] + * [STORED AS file_format] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] + * + * partition_fields: + * col_name, transform(col_name), transform(constant, col_name), ... | + * col_name data_type [NOT NULL] [COMMENT col_comment], ... + * }}} + */ + override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { + val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + + val columns = Option(ctx.colTypeList()).map(visitColTypeList).getOrElse(Nil) + val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) + val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) = + visitCreateTableClauses(ctx.createTableClauses()) + + if (provider.isDefined && serdeInfo.isDefined) { + operationNotAllowed(s"CREATE TABLE ... 
USING ... ${serdeInfo.get.describe}", ctx) + } + + if (temp) { + val asSelect = if (ctx.query == null) "" else " AS ..." + operationNotAllowed( + s"CREATE TEMPORARY TABLE ...$asSelect, use CREATE TEMPORARY VIEW instead", ctx) + } + + val partitioning = partitionExpressions(partTransforms, partCols, ctx) + + Option(ctx.query).map(plan) match { + case Some(_) if columns.nonEmpty => + operationNotAllowed( + "Schema may not be specified in a Create Table As Select (CTAS) statement", + ctx) + + case Some(_) if partCols.nonEmpty => + // non-reference partition columns are not allowed because schema can't be specified + operationNotAllowed( + "Partition column types may not be specified in Create Table As Select (CTAS)", + ctx) + + case Some(query) => + CreateTableAsSelectStatement( + table, query, partitioning, bucketSpec, properties, provider, options, location, comment, + writeOptions = Map.empty, serdeInfo, external = external, ifNotExists = ifNotExists) + + case _ => + // Note: table schema includes both the table columns list and the partition columns + // with data type. + val schema = StructType(columns ++ partCols) + CreateTableStatement(table, schema, partitioning, bucketSpec, properties, provider, + options, location, comment, serdeInfo, external = external, ifNotExists = ifNotExists) + } + } + + /** + * Replace a table, returning a [[ReplaceTableStatement]] logical plan. + * + * Expected format: + * {{{ + * [CREATE OR] REPLACE TABLE [db_name.]table_name + * [USING table_provider] + * replace_table_clauses + * [[AS] select_statement]; + * + * replace_table_clauses (order insensitive): + * [OPTIONS table_property_list] + * [PARTITIONED BY (partition_fields)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] + * + * partition_fields: + * col_name, transform(col_name), transform(constant, col_name), ... | + * col_name data_type [NOT NULL] [COMMENT col_comment], ... + * }}} + */ + override def visitReplaceTable(ctx: ReplaceTableContext): LogicalPlan = withOrigin(ctx) { + val (table, temp, ifNotExists, external) = visitReplaceTableHeader(ctx.replaceTableHeader) + val orCreate = ctx.replaceTableHeader().CREATE() != null + + if (temp) { + val action = if (orCreate) "CREATE OR REPLACE" else "REPLACE" + operationNotAllowed(s"$action TEMPORARY TABLE ..., use $action TEMPORARY VIEW instead.", ctx) + } + + if (external) { + operationNotAllowed("REPLACE EXTERNAL TABLE ...", ctx) + } + + if (ifNotExists) { + operationNotAllowed("REPLACE ... IF NOT EXISTS, use CREATE IF NOT EXISTS instead", ctx) + } + + val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) = + visitCreateTableClauses(ctx.createTableClauses()) + val columns = Option(ctx.colTypeList()).map(visitColTypeList).getOrElse(Nil) + val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) + + if (provider.isDefined && serdeInfo.isDefined) { + operationNotAllowed(s"CREATE TABLE ... USING ... 
${serdeInfo.get.describe}", ctx) + } + + val partitioning = partitionExpressions(partTransforms, partCols, ctx) + + Option(ctx.query).map(plan) match { + case Some(_) if columns.nonEmpty => + operationNotAllowed( + "Schema may not be specified in a Replace Table As Select (RTAS) statement", + ctx) + + case Some(_) if partCols.nonEmpty => + // non-reference partition columns are not allowed because schema can't be specified + operationNotAllowed( + "Partition column types may not be specified in Replace Table As Select (RTAS)", + ctx) + + case Some(query) => + ReplaceTableAsSelectStatement(table, query, partitioning, bucketSpec, properties, + provider, options, location, comment, writeOptions = Map.empty, serdeInfo, + orCreate = orCreate) + + case _ => + // Note: table schema includes both the table columns list and the partition columns + // with data type. + val schema = StructType(columns ++ partCols) + ReplaceTableStatement(table, schema, partitioning, bucketSpec, properties, provider, + options, location, comment, serdeInfo, orCreate = orCreate) + } + } + + /** + * Parse new column info from ADD COLUMN into a QualifiedColType. + */ + override def visitQualifiedColTypeWithPosition( + ctx: QualifiedColTypeWithPositionContext): QualifiedColType = withOrigin(ctx) { + QualifiedColType( + name = typedVisit[Seq[String]](ctx.name), + dataType = typedVisit[DataType](ctx.dataType), + nullable = ctx.NULL == null, + comment = Option(ctx.commentSpec()).map(visitCommentSpec), + position = Option(ctx.colPosition).map(typedVisit[ColumnPosition])) + } + + /** + * Create or replace a view. This creates a [[CreateViewStatement]] + * + * For example: + * {{{ + * CREATE [OR REPLACE] [[GLOBAL] TEMPORARY] VIEW [IF NOT EXISTS] multi_part_name + * [(column_name [COMMENT column_comment], ...) ] + * create_view_clauses + * + * AS SELECT ...; + * + * create_view_clauses (order insensitive): + * [COMMENT view_comment] + * [TBLPROPERTIES (property_name = property_value, ...)] + * }}} + */ + override def visitCreateView(ctx: CreateViewContext): LogicalPlan = withOrigin(ctx) { + if (!ctx.identifierList.isEmpty) { + operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx) + } + + checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED ON", ctx) + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + + val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl => + icl.identifierComment.asScala.map { ic => + ic.identifier.getText -> Option(ic.commentSpec()).map(visitCommentSpec) + } + } + + val properties = ctx.tablePropertyList.asScala.headOption.map(visitPropertyKeyValues) + .getOrElse(Map.empty) + if (ctx.TEMPORARY != null && properties.nonEmpty) { + operationNotAllowed("TBLPROPERTIES can't coexist with CREATE TEMPORARY VIEW", ctx) + } + + val viewType = if (ctx.TEMPORARY == null) { + PersistedView + } else if (ctx.GLOBAL != null) { + GlobalTempView + } else { + LocalTempView + } + CreateViewStatement( + visitMultipartIdentifier(ctx.multipartIdentifier), + userSpecifiedColumns, + visitCommentSpecList(ctx.commentSpec()), + properties, + Option(source(ctx.query)), + plan(ctx.query), + ctx.EXISTS != null, + ctx.REPLACE != null, + viewType) + } + + /** + * Alter the query of a view. 
This creates a [[AlterViewAsStatement]] + * + * For example: + * {{{ + * ALTER VIEW multi_part_name AS SELECT ...; + * }}} + */ + override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) { + AlterViewAsStatement( + visitMultipartIdentifier(ctx.multipartIdentifier), + originalText = source(ctx.query), + query = plan(ctx.query)) + } +} + +/** + * A container for holding named common table expressions (CTEs) and a query plan. + * This operator will be removed during analysis and the relations will be substituted into child. + * + * @param child The final query of this CTE. + * @param cteRelations A sequence of pair (alias, the CTE definition) that this CTE defined + * Each CTE can see the base tables and the previously defined CTEs only. + */ +case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) extends UnaryNode { + override def output: Seq[Attribute] = child.output + + override def simpleString(maxFields: Int): String = { + val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]", maxFields) + s"CTE $cteAliases" + } + + override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2) + + def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = this +} + diff --git a/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlParser.scala b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlParser.scala new file mode 100644 index 0000000000000..0d07ad12f1025 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3.1.x/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_1ExtendedSqlParser.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.parser + +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException} +import org.antlr.v4.runtime.tree.TerminalNodeImpl +import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser.{NonReservedContext, QuotedIdentifierContext} +import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseListener, HoodieSqlBaseLexer, HoodieSqlBaseParser} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, SparkSession} + +class HoodieSpark3_1ExtendedSqlParser(session: SparkSession, delegate: ParserInterface) + extends ParserInterface with Logging { + + private lazy val conf = session.sqlContext.conf + private lazy val builder = new HoodieSpark3_1ExtendedSqlAstBuilder(conf, delegate) + + override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => + builder.visit(parser.singleStatement()) match { + case plan: LogicalPlan => plan + case _=> delegate.parsePlan(sqlText) + } + } + + override def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText) + + override def parseTableIdentifier(sqlText: String): TableIdentifier = + delegate.parseTableIdentifier(sqlText) + + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = + delegate.parseFunctionIdentifier(sqlText) + + override def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText) + + override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText) + + protected def parse[T](command: String)(toResult: HoodieSqlBaseParser => T): T = { + logDebug(s"Parsing command: $command") + + val lexer = new HoodieSqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command))) + lexer.removeErrorListeners() + lexer.addErrorListener(ParseErrorListener) + + val tokenStream = new CommonTokenStream(lexer) + val parser = new HoodieSqlBaseParser(tokenStream) + parser.addParseListener(PostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(ParseErrorListener) +// parser.legacy_setops_precedence_enabled = conf.setOpsPrecedenceEnforced + parser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled + parser.SQL_standard_keyword_behavior = conf.ansiEnabled + + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) + } + catch { + case e: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.seek(0) // rewind input stream + parser.reset() + + // Try Again. 
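+ // LL prediction is slower but complete, so the retry can succeed where SLL prediction gave up.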
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) + } + } + catch { + case e: ParseException if e.command.isDefined => + throw e + case e: ParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new ParseException(Option(command), e.message, position, position) + } + } + + override def parseMultipartIdentifier(sqlText: String): Seq[String] = { + delegate.parseMultipartIdentifier(sqlText) + } +} + +/** + * Fork from `org.apache.spark.sql.catalyst.parser.UpperCaseCharStream`. + */ +class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { + override def consume(): Unit = wrapped.consume + override def getSourceName(): String = wrapped.getSourceName + override def index(): Int = wrapped.index + override def mark(): Int = wrapped.mark + override def release(marker: Int): Unit = wrapped.release(marker) + override def seek(where: Int): Unit = wrapped.seek(where) + override def size(): Int = wrapped.size + + override def getText(interval: Interval): String = { + // ANTLR 4.7's CodePointCharStream implementations have bugs when + // getText() is called with an empty stream, or intervals where + // the start > end. See + // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix + // that is not yet in a released ANTLR artifact. + if (size() > 0 && (interval.b - interval.a >= 0)) { + wrapped.getText(interval) + } else { + "" + } + } + // scalastyle:off + override def LA(i: Int): Int = { + // scalastyle:on + val la = wrapped.LA(i) + if (la == 0 || la == IntStream.EOF) la + else Character.toUpperCase(la) + } +} + +/** + * Fork from `org.apache.spark.sql.catalyst.parser.PostProcessor`. + */ +case object PostProcessor extends HoodieSqlBaseBaseListener { + + /** Remove the back ticks from an Identifier. */ + override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. 
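+ * The matched token is re-tagged as an IDENTIFIER so non-reserved keywords can be used as ordinary names.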
*/ + override def exitNonReserved(ctx: NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier( + ctx: ParserRuleContext, + stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + val newToken = new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + HoodieSqlBaseParser.IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins) + parent.addChild(new TerminalNodeImpl(f(newToken))) + } +} diff --git a/hudi-spark-datasource/hudi-spark3/pom.xml b/hudi-spark-datasource/hudi-spark3/pom.xml index d8dba8384886c..1900f82474999 100644 --- a/hudi-spark-datasource/hudi-spark3/pom.xml +++ b/hudi-spark-datasource/hudi-spark3/pom.xml @@ -144,6 +144,24 @@ org.jacoco jacoco-maven-plugin + + org.antlr + antlr4-maven-plugin + ${antlr.version} + + + + antlr4 + + + + + true + true + ../hudi-spark3/src/main/antlr4 + ../hudi-spark3/src/main/antlr4/imports + + @@ -157,7 +175,7 @@ org.apache.spark spark-sql_2.12 - ${spark3.version} + ${spark3.2.version} true diff --git a/hudi-spark-datasource/hudi-spark3/src/main/antlr4/imports/SqlBase.g4 b/hudi-spark-datasource/hudi-spark3/src/main/antlr4/imports/SqlBase.g4 new file mode 100644 index 0000000000000..d4e1e48351ccc --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3/src/main/antlr4/imports/SqlBase.g4 @@ -0,0 +1,1908 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. + */ + +// The parser file is forked from spark 3.2.0's SqlBase.g4. +grammar SqlBase; + +@parser::members { + /** + * When false, INTERSECT is given the greater precedence over the other set + * operations (UNION, EXCEPT and MINUS) as per the SQL standard. + */ + public boolean legacy_setops_precedence_enabled = false; + + /** + * When false, a literal with an exponent would be converted into + * double type rather than decimal type. + */ + public boolean legacy_exponent_literal_as_decimal_enabled = false; + + /** + * When true, the behavior of keywords follows ANSI SQL standard. + */ + public boolean SQL_standard_keyword_behavior = false; +} + +@lexer::members { + /** + * Verify whether current token is a valid decimal token (which contains dot). + * Returns true if the character that follows the token is not a digit or letter or underscore. + * + * For example: + * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. + * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. + * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed + * by a space. 
34.E2 is a valid decimal token because it is followed by symbol '+' + * which is not a digit or letter or underscore. + */ + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + + /** + * This method will be called when we see '/*' and try to match it as a bracketed comment. + * If the next character is '+', it should be parsed as hint later, and we cannot match + * it as a bracketed comment. + * + * Returns true if the next character is '+'. + */ + public boolean isHint() { + int nextChar = _input.LA(1); + if (nextChar == '+') { + return true; + } else { + return false; + } + } +} + +singleStatement + : statement ';'* EOF + ; + +singleExpression + : namedExpression EOF + ; + +singleTableIdentifier + : tableIdentifier EOF + ; + +singleMultipartIdentifier + : multipartIdentifier EOF + ; + +singleFunctionIdentifier + : functionIdentifier EOF + ; + +singleDataType + : dataType EOF + ; + +singleTableSchema + : colTypeList EOF + ; + +statement + : query #statementDefault + | ctes? dmlStatementNoWith #dmlStatement + | USE NAMESPACE? multipartIdentifier #use + | CREATE namespace (IF NOT EXISTS)? multipartIdentifier + (commentSpec | + locationSpec | + (WITH (DBPROPERTIES | PROPERTIES) tablePropertyList))* #createNamespace + | ALTER namespace multipartIdentifier + SET (DBPROPERTIES | PROPERTIES) tablePropertyList #setNamespaceProperties + | ALTER namespace multipartIdentifier + SET locationSpec #setNamespaceLocation + | DROP namespace (IF EXISTS)? multipartIdentifier + (RESTRICT | CASCADE)? #dropNamespace + | SHOW (DATABASES | NAMESPACES) ((FROM | IN) multipartIdentifier)? + (LIKE? pattern=STRING)? #showNamespaces + | createTableHeader ('(' colTypeList ')')? tableProvider? + createTableClauses + (AS? query)? #createTable + | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier + LIKE source=tableIdentifier + (tableProvider | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* #createTableLike + | replaceTableHeader ('(' colTypeList ')')? tableProvider? + createTableClauses + (AS? query)? #replaceTable + | ANALYZE TABLE multipartIdentifier partitionSpec? COMPUTE STATISTICS + (identifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze + | ANALYZE TABLES ((FROM | IN) multipartIdentifier)? COMPUTE STATISTICS + (identifier)? #analyzeTables + | ALTER TABLE multipartIdentifier + ADD (COLUMN | COLUMNS) + columns=qualifiedColTypeWithPositionList #addTableColumns + | ALTER TABLE multipartIdentifier + ADD (COLUMN | COLUMNS) + '(' columns=qualifiedColTypeWithPositionList ')' #addTableColumns + | ALTER TABLE table=multipartIdentifier + RENAME COLUMN + from=multipartIdentifier TO to=errorCapturingIdentifier #renameTableColumn + | ALTER TABLE multipartIdentifier + DROP (COLUMN | COLUMNS) + '(' columns=multipartIdentifierList ')' #dropTableColumns + | ALTER TABLE multipartIdentifier + DROP (COLUMN | COLUMNS) columns=multipartIdentifierList #dropTableColumns + | ALTER (TABLE | VIEW) from=multipartIdentifier + RENAME TO to=multipartIdentifier #renameTable + | ALTER (TABLE | VIEW) multipartIdentifier + SET TBLPROPERTIES tablePropertyList #setTableProperties + | ALTER (TABLE | VIEW) multipartIdentifier + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties + | ALTER TABLE table=multipartIdentifier + (ALTER | CHANGE) COLUMN? column=multipartIdentifier + alterColumnAction? 
#alterTableAlterColumn + | ALTER TABLE table=multipartIdentifier partitionSpec? + CHANGE COLUMN? + colName=multipartIdentifier colType colPosition? #hiveChangeColumn + | ALTER TABLE table=multipartIdentifier partitionSpec? + REPLACE COLUMNS + '(' columns=qualifiedColTypeWithPositionList ')' #hiveReplaceColumns + | ALTER TABLE multipartIdentifier (partitionSpec)? + SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe + | ALTER TABLE multipartIdentifier (partitionSpec)? + SET SERDEPROPERTIES tablePropertyList #setTableSerDe + | ALTER (TABLE | VIEW) multipartIdentifier ADD (IF NOT EXISTS)? + partitionSpecLocation+ #addTablePartition + | ALTER TABLE multipartIdentifier + from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition + | ALTER (TABLE | VIEW) multipartIdentifier + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions + | ALTER TABLE multipartIdentifier + (partitionSpec)? SET locationSpec #setTableLocation + | ALTER TABLE multipartIdentifier RECOVER PARTITIONS #recoverPartitions + | DROP TABLE (IF EXISTS)? multipartIdentifier PURGE? #dropTable + | DROP VIEW (IF EXISTS)? multipartIdentifier #dropView + | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? + VIEW (IF NOT EXISTS)? multipartIdentifier + identifierCommentList? + (commentSpec | + (PARTITIONED ON identifierList) | + (TBLPROPERTIES tablePropertyList))* + AS query #createView + | CREATE (OR REPLACE)? GLOBAL? TEMPORARY VIEW + tableIdentifier ('(' colTypeList ')')? tableProvider + (OPTIONS tablePropertyList)? #createTempViewUsing + | ALTER VIEW multipartIdentifier AS? query #alterViewQuery + | CREATE (OR REPLACE)? TEMPORARY? FUNCTION (IF NOT EXISTS)? + multipartIdentifier AS className=STRING + (USING resource (',' resource)*)? #createFunction + | DROP TEMPORARY? FUNCTION (IF EXISTS)? multipartIdentifier #dropFunction + | EXPLAIN (LOGICAL | FORMATTED | EXTENDED | CODEGEN | COST)? + statement #explain + | SHOW TABLES ((FROM | IN) multipartIdentifier)? + (LIKE? pattern=STRING)? #showTables + | SHOW TABLE EXTENDED ((FROM | IN) ns=multipartIdentifier)? + LIKE pattern=STRING partitionSpec? #showTableExtended + | SHOW TBLPROPERTIES table=multipartIdentifier + ('(' key=tablePropertyKey ')')? #showTblProperties + | SHOW COLUMNS (FROM | IN) table=multipartIdentifier + ((FROM | IN) ns=multipartIdentifier)? #showColumns + | SHOW VIEWS ((FROM | IN) multipartIdentifier)? + (LIKE? pattern=STRING)? #showViews + | SHOW PARTITIONS multipartIdentifier partitionSpec? #showPartitions + | SHOW identifier? FUNCTIONS + (LIKE? (multipartIdentifier | pattern=STRING))? #showFunctions + | SHOW CREATE TABLE multipartIdentifier (AS SERDE)? #showCreateTable + | SHOW CURRENT NAMESPACE #showCurrentNamespace + | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction + | (DESC | DESCRIBE) namespace EXTENDED? + multipartIdentifier #describeNamespace + | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? + multipartIdentifier partitionSpec? describeColName? #describeRelation + | (DESC | DESCRIBE) QUERY? query #describeQuery + | COMMENT ON namespace multipartIdentifier IS + comment=(STRING | NULL) #commentNamespace + | COMMENT ON TABLE multipartIdentifier IS comment=(STRING | NULL) #commentTable + | REFRESH TABLE multipartIdentifier #refreshTable + | REFRESH FUNCTION multipartIdentifier #refreshFunction + | REFRESH (STRING | .*?) #refreshResource + | CACHE LAZY? TABLE multipartIdentifier + (OPTIONS options=tablePropertyList)? (AS? query)? #cacheTable + | UNCACHE TABLE (IF EXISTS)? 
multipartIdentifier #uncacheTable + | CLEAR CACHE #clearCache + | LOAD DATA LOCAL? INPATH path=STRING OVERWRITE? INTO TABLE + multipartIdentifier partitionSpec? #loadData + | TRUNCATE TABLE multipartIdentifier partitionSpec? #truncateTable + | MSCK REPAIR TABLE multipartIdentifier + (option=(ADD|DROP|SYNC) PARTITIONS)? #repairTable + | op=(ADD | LIST) identifier .*? #manageResource + | SET ROLE .*? #failNativeCommand + | SET TIME ZONE interval #setTimeZone + | SET TIME ZONE timezone=(STRING | LOCAL) #setTimeZone + | SET TIME ZONE .*? #setTimeZone + | SET configKey EQ configValue #setQuotedConfiguration + | SET configKey (EQ .*?)? #setQuotedConfiguration + | SET .*? EQ configValue #setQuotedConfiguration + | SET .*? #setConfiguration + | RESET configKey #resetQuotedConfiguration + | RESET .*? #resetConfiguration + | unsupportedHiveNativeCommands .*? #failNativeCommand + ; + +configKey + : quotedIdentifier + ; + +configValue + : quotedIdentifier + ; + +unsupportedHiveNativeCommands + : kw1=CREATE kw2=ROLE + | kw1=DROP kw2=ROLE + | kw1=GRANT kw2=ROLE? + | kw1=REVOKE kw2=ROLE? + | kw1=SHOW kw2=GRANT + | kw1=SHOW kw2=ROLE kw3=GRANT? + | kw1=SHOW kw2=PRINCIPALS + | kw1=SHOW kw2=ROLES + | kw1=SHOW kw2=CURRENT kw3=ROLES + | kw1=EXPORT kw2=TABLE + | kw1=IMPORT kw2=TABLE + | kw1=SHOW kw2=COMPACTIONS + | kw1=SHOW kw2=CREATE kw3=TABLE + | kw1=SHOW kw2=TRANSACTIONS + | kw1=SHOW kw2=INDEXES + | kw1=SHOW kw2=LOCKS + | kw1=CREATE kw2=INDEX + | kw1=DROP kw2=INDEX + | kw1=ALTER kw2=INDEX + | kw1=LOCK kw2=TABLE + | kw1=LOCK kw2=DATABASE + | kw1=UNLOCK kw2=TABLE + | kw1=UNLOCK kw2=DATABASE + | kw1=CREATE kw2=TEMPORARY kw3=MACRO + | kw1=DROP kw2=TEMPORARY kw3=MACRO + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=CLUSTERED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=CLUSTERED kw4=BY + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SORTED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=SKEWED kw4=BY + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=SKEWED + | kw1=ALTER kw2=TABLE tableIdentifier kw3=NOT kw4=STORED kw5=AS kw6=DIRECTORIES + | kw1=ALTER kw2=TABLE tableIdentifier kw3=SET kw4=SKEWED kw5=LOCATION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=EXCHANGE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=ARCHIVE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=UNARCHIVE kw4=PARTITION + | kw1=ALTER kw2=TABLE tableIdentifier kw3=TOUCH + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=COMPACT + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=CONCATENATE + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=SET kw4=FILEFORMAT + | kw1=ALTER kw2=TABLE tableIdentifier partitionSpec? kw3=REPLACE kw4=COLUMNS + | kw1=START kw2=TRANSACTION + | kw1=COMMIT + | kw1=ROLLBACK + | kw1=DFS + ; + +createTableHeader + : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? multipartIdentifier + ; + +replaceTableHeader + : (CREATE OR)? REPLACE TABLE multipartIdentifier + ; + +bucketSpec + : CLUSTERED BY identifierList + (SORTED BY orderedIdentifierList)? + INTO INTEGER_VALUE BUCKETS + ; + +skewSpec + : SKEWED BY identifierList + ON (constantList | nestedConstantList) + (STORED AS DIRECTORIES)? + ; + +locationSpec + : LOCATION STRING + ; + +commentSpec + : COMMENT STRING + ; + +query + : ctes? queryTerm queryOrganization + ; + +insertInto + : INSERT OVERWRITE TABLE? multipartIdentifier (partitionSpec (IF NOT EXISTS)?)? identifierList? #insertOverwriteTable + | INSERT INTO TABLE? multipartIdentifier partitionSpec? (IF NOT EXISTS)? identifierList? #insertIntoTable + | INSERT OVERWRITE LOCAL? 
DIRECTORY path=STRING rowFormat? createFileFormat? #insertOverwriteHiveDir + | INSERT OVERWRITE LOCAL? DIRECTORY (path=STRING)? tableProvider (OPTIONS options=tablePropertyList)? #insertOverwriteDir + ; + +partitionSpecLocation + : partitionSpec locationSpec? + ; + +partitionSpec + : PARTITION '(' partitionVal (',' partitionVal)* ')' + ; + +partitionVal + : identifier (EQ constant)? + ; + +namespace + : NAMESPACE + | DATABASE + | SCHEMA + ; + +describeFuncName + : qualifiedName + | STRING + | comparisonOperator + | arithmeticOperator + | predicateOperator + ; + +describeColName + : nameParts+=identifier ('.' nameParts+=identifier)* + ; + +ctes + : WITH namedQuery (',' namedQuery)* + ; + +namedQuery + : name=errorCapturingIdentifier (columnAliases=identifierList)? AS? '(' query ')' + ; + +tableProvider + : USING multipartIdentifier + ; + +createTableClauses + :((OPTIONS options=tablePropertyList) | + (PARTITIONED BY partitioning=partitionFieldList) | + skewSpec | + bucketSpec | + rowFormat | + createFileFormat | + locationSpec | + commentSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* + ; + +tablePropertyList + : '(' tableProperty (',' tableProperty)* ')' + ; + +tableProperty + : key=tablePropertyKey (EQ? value=tablePropertyValue)? + ; + +tablePropertyKey + : identifier ('.' identifier)* + | STRING + ; + +tablePropertyValue + : INTEGER_VALUE + | DECIMAL_VALUE + | booleanValue + | STRING + ; + +constantList + : '(' constant (',' constant)* ')' + ; + +nestedConstantList + : '(' constantList (',' constantList)* ')' + ; + +createFileFormat + : STORED AS fileFormat + | STORED BY storageHandler + ; + +fileFormat + : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING #tableFileFormat + | identifier #genericFileFormat + ; + +storageHandler + : STRING (WITH SERDEPROPERTIES tablePropertyList)? + ; + +resource + : identifier STRING + ; + +dmlStatementNoWith + : insertInto queryTerm queryOrganization #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery + | DELETE FROM multipartIdentifier tableAlias whereClause? #deleteFromTable + | UPDATE multipartIdentifier tableAlias setClause whereClause? #updateTable + | MERGE INTO target=multipartIdentifier targetAlias=tableAlias + USING (source=multipartIdentifier | + '(' sourceQuery=query')') sourceAlias=tableAlias + ON mergeCondition=booleanExpression + matchedClause* + notMatchedClause* #mergeIntoTable + ; + +queryOrganization + : (ORDER BY order+=sortItem (',' order+=sortItem)*)? + (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)? + (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)? + (SORT BY sort+=sortItem (',' sort+=sortItem)*)? + windowClause? + (LIMIT (ALL | limit=expression))? + ; + +multiInsertQueryBody + : insertInto fromStatementBody + ; + +queryTerm + : queryPrimary #queryTermDefault + | left=queryTerm {legacy_setops_precedence_enabled}? + operator=(INTERSECT | UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + | left=queryTerm {!legacy_setops_precedence_enabled}? + operator=INTERSECT setQuantifier? right=queryTerm #setOperation + | left=queryTerm {!legacy_setops_precedence_enabled}? + operator=(UNION | EXCEPT | SETMINUS) setQuantifier? right=queryTerm #setOperation + ; + +queryPrimary + : querySpecification #queryPrimaryDefault + | fromStatement #fromStmt + | TABLE multipartIdentifier #table + | inlineTable #inlineTableDefault1 + | '(' query ')' #subquery + ; + +sortItem + : expression ordering=(ASC | DESC)? (NULLS nullOrder=(LAST | FIRST))? 
+ ; + +fromStatement + : fromClause fromStatementBody+ + ; + +fromStatementBody + : transformClause + whereClause? + queryOrganization + | selectClause + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? + queryOrganization + ; + +querySpecification + : transformClause + fromClause? + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? #transformQuerySpecification + | selectClause + fromClause? + lateralView* + whereClause? + aggregationClause? + havingClause? + windowClause? #regularQuerySpecification + ; + +transformClause + : (SELECT kind=TRANSFORM '(' setQuantifier? expressionSeq ')' + | kind=MAP setQuantifier? expressionSeq + | kind=REDUCE setQuantifier? expressionSeq) + inRowFormat=rowFormat? + (RECORDWRITER recordWriter=STRING)? + USING script=STRING + (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))? + outRowFormat=rowFormat? + (RECORDREADER recordReader=STRING)? + ; + +selectClause + : SELECT (hints+=hint)* setQuantifier? namedExpressionSeq + ; + +setClause + : SET assignmentList + ; + +matchedClause + : WHEN MATCHED (AND matchedCond=booleanExpression)? THEN matchedAction + ; +notMatchedClause + : WHEN NOT MATCHED (AND notMatchedCond=booleanExpression)? THEN notMatchedAction + ; + +matchedAction + : DELETE + | UPDATE SET ASTERISK + | UPDATE SET assignmentList + ; + +notMatchedAction + : INSERT ASTERISK + | INSERT '(' columns=multipartIdentifierList ')' + VALUES '(' expression (',' expression)* ')' + ; + +assignmentList + : assignment (',' assignment)* + ; + +assignment + : key=multipartIdentifier EQ value=expression + ; + +whereClause + : WHERE booleanExpression + ; + +havingClause + : HAVING booleanExpression + ; + +hint + : '/*+' hintStatements+=hintStatement (','? hintStatements+=hintStatement)* '*/' + ; + +hintStatement + : hintName=identifier + | hintName=identifier '(' parameters+=primaryExpression (',' parameters+=primaryExpression)* ')' + ; + +fromClause + : FROM relation (',' relation)* lateralView* pivotClause? + ; + +temporalClause + : FOR? (SYSTEM_TIME | TIMESTAMP) AS OF timestamp=valueExpression + | FOR? (SYSTEM_VERSION | VERSION) AS OF version=(INTEGER_VALUE | STRING) + ; + +aggregationClause + : GROUP BY groupingExpressionsWithGroupingAnalytics+=groupByClause + (',' groupingExpressionsWithGroupingAnalytics+=groupByClause)* + | GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( + WITH kind=ROLLUP + | WITH kind=CUBE + | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? + ; + +groupByClause + : groupingAnalytics + | expression + ; + +groupingAnalytics + : (ROLLUP | CUBE) '(' groupingSet (',' groupingSet)* ')' + | GROUPING SETS '(' groupingElement (',' groupingElement)* ')' + ; + +groupingElement + : groupingAnalytics + | groupingSet + ; + +groupingSet + : '(' (expression (',' expression)*)? ')' + | expression + ; + +pivotClause + : PIVOT '(' aggregates=namedExpressionSeq FOR pivotColumn IN '(' pivotValues+=pivotValue (',' pivotValues+=pivotValue)* ')' ')' + ; + +pivotColumn + : identifiers+=identifier + | '(' identifiers+=identifier (',' identifiers+=identifier)* ')' + ; + +pivotValue + : expression (AS? identifier)? + ; + +lateralView + : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? + ; + +setQuantifier + : DISTINCT + | ALL + ; + +relation + : LATERAL? relationPrimary joinRelation* + ; + +joinRelation + : (joinType) JOIN LATERAL? 
right=relationPrimary joinCriteria? + | NATURAL joinType JOIN LATERAL? right=relationPrimary + ; + +joinType + : INNER? + | CROSS + | LEFT OUTER? + | LEFT? SEMI + | RIGHT OUTER? + | FULL OUTER? + | LEFT? ANTI + ; + +joinCriteria + : ON booleanExpression + | USING identifierList + ; + +sample + : TABLESAMPLE '(' sampleMethod? ')' + ; + +sampleMethod + : negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile + | expression ROWS #sampleByRows + | sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE + (ON (identifier | qualifiedName '(' ')'))? #sampleByBucket + | bytes=expression #sampleByBytes + ; + +identifierList + : '(' identifierSeq ')' + ; + +identifierSeq + : ident+=errorCapturingIdentifier (',' ident+=errorCapturingIdentifier)* + ; + +orderedIdentifierList + : '(' orderedIdentifier (',' orderedIdentifier)* ')' + ; + +orderedIdentifier + : ident=errorCapturingIdentifier ordering=(ASC | DESC)? + ; + +identifierCommentList + : '(' identifierComment (',' identifierComment)* ')' + ; + +identifierComment + : identifier commentSpec? + ; + +relationPrimary + : multipartIdentifier temporalClause? + sample? tableAlias #tableName + | '(' query ')' sample? tableAlias #aliasedQuery + | '(' relation ')' sample? tableAlias #aliasedRelation + | inlineTable #inlineTableDefault2 + | functionTable #tableValuedFunction + ; + +inlineTable + : VALUES expression (',' expression)* tableAlias + ; + +functionTable + : funcName=functionName '(' (expression (',' expression)*)? ')' tableAlias + ; + +tableAlias + : (AS? strictIdentifier identifierList?)? + ; + +rowFormat + : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde + | ROW FORMAT DELIMITED + (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)? + (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)? + (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)? + (LINES TERMINATED BY linesSeparatedBy=STRING)? + (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited + ; + +multipartIdentifierList + : multipartIdentifier (',' multipartIdentifier)* + ; + +multipartIdentifier + : parts+=errorCapturingIdentifier ('.' parts+=errorCapturingIdentifier)* + ; + +tableIdentifier + : (db=errorCapturingIdentifier '.')? table=errorCapturingIdentifier + ; + +functionIdentifier + : (db=errorCapturingIdentifier '.')? function=errorCapturingIdentifier + ; + +namedExpression + : expression (AS? (name=errorCapturingIdentifier | identifierList))? + ; + +namedExpressionSeq + : namedExpression (',' namedExpression)* + ; + +partitionFieldList + : '(' fields+=partitionField (',' fields+=partitionField)* ')' + ; + +partitionField + : transform #partitionTransform + | colType #partitionColumn + ; + +transform + : qualifiedName #identityTransform + | transformName=identifier + '(' argument+=transformArgument (',' argument+=transformArgument)* ')' #applyTransform + ; + +transformArgument + : qualifiedName + | constant + ; + +expression + : booleanExpression + ; + +expressionSeq + : expression (',' expression)* + ; + +booleanExpression + : NOT booleanExpression #logicalNot + | EXISTS '(' query ')' #exists + | valueExpression predicate? #predicated + | left=booleanExpression operator=AND right=booleanExpression #logicalBinary + | left=booleanExpression operator=OR right=booleanExpression #logicalBinary + ; + +predicate + : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression + | NOT? kind=IN '(' expression (',' expression)* ')' + | NOT? 
kind=IN '(' query ')' + | NOT? kind=RLIKE pattern=valueExpression + | NOT? kind=LIKE quantifier=(ANY | SOME | ALL) ('('')' | '(' expression (',' expression)* ')') + | NOT? kind=LIKE pattern=valueExpression (ESCAPE escapeChar=STRING)? + | IS NOT? kind=NULL + | IS NOT? kind=(TRUE | FALSE | UNKNOWN) + | IS NOT? kind=DISTINCT FROM right=valueExpression + ; + +valueExpression + : primaryExpression #valueExpressionDefault + | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary + | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(PLUS | MINUS | CONCAT_PIPE) right=valueExpression #arithmeticBinary + | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary + | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary + | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary + | left=valueExpression comparisonOperator right=valueExpression #comparison + ; + +primaryExpression + : name=(CURRENT_DATE | CURRENT_TIMESTAMP | CURRENT_USER) #currentLike + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | name=(CAST | TRY_CAST) '(' expression AS dataType ')' #cast + | STRUCT '(' (argument+=namedExpression (',' argument+=namedExpression)*)? ')' #struct + | FIRST '(' expression (IGNORE NULLS)? ')' #first + | LAST '(' expression (IGNORE NULLS)? ')' #last + | POSITION '(' substr=valueExpression IN str=valueExpression ')' #position + | constant #constantDefault + | ASTERISK #star + | qualifiedName '.' ASTERISK #star + | '(' namedExpression (',' namedExpression)+ ')' #rowConstructor + | '(' query ')' #subqueryExpression + | functionName '(' (setQuantifier? argument+=expression (',' argument+=expression)*)? ')' + (FILTER '(' WHERE where=booleanExpression ')')? + (nullsOption=(IGNORE | RESPECT) NULLS)? ( OVER windowSpec)? #functionCall + | identifier '->' expression #lambda + | '(' identifier (',' identifier)+ ')' '->' expression #lambda + | value=primaryExpression '[' index=valueExpression ']' #subscript + | identifier #columnReference + | base=primaryExpression '.' fieldName=identifier #dereference + | '(' expression ')' #parenthesizedExpression + | EXTRACT '(' field=identifier FROM source=valueExpression ')' #extract + | (SUBSTR | SUBSTRING) '(' str=valueExpression (FROM | ',') pos=valueExpression + ((FOR | ',') len=valueExpression)? ')' #substring + | TRIM '(' trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)? + FROM srcStr=valueExpression ')' #trim + | OVERLAY '(' input=valueExpression PLACING replace=valueExpression + FROM position=valueExpression (FOR length=valueExpression)? ')' #overlay + ; + +constant + : NULL #nullLiteral + | interval #intervalLiteral + | identifier STRING #typeConstructor + | number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + ; + +comparisonOperator + : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ + ; + +arithmeticOperator + : PLUS | MINUS | ASTERISK | SLASH | PERCENT | DIV | TILDE | AMPERSAND | PIPE | CONCAT_PIPE | HAT + ; + +predicateOperator + : OR | AND | IN | NOT + ; + +booleanValue + : TRUE | FALSE + ; + +interval + : INTERVAL (errorCapturingMultiUnitsInterval | errorCapturingUnitToUnitInterval)? + ; + +errorCapturingMultiUnitsInterval + : body=multiUnitsInterval unitToUnitInterval? 
+ ; + +multiUnitsInterval + : (intervalValue unit+=identifier)+ + ; + +errorCapturingUnitToUnitInterval + : body=unitToUnitInterval (error1=multiUnitsInterval | error2=unitToUnitInterval)? + ; + +unitToUnitInterval + : value=intervalValue from=identifier TO to=identifier + ; + +intervalValue + : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE | STRING) + ; + +colPosition + : position=FIRST | position=AFTER afterCol=errorCapturingIdentifier + ; + +dataType + : complex=ARRAY '<' dataType '>' #complexDataType + | complex=MAP '<' dataType ',' dataType '>' #complexDataType + | complex=STRUCT ('<' complexColTypeList? '>' | NEQ) #complexDataType + | INTERVAL from=(YEAR | MONTH) (TO to=MONTH)? #yearMonthIntervalDataType + | INTERVAL from=(DAY | HOUR | MINUTE | SECOND) + (TO to=(HOUR | MINUTE | SECOND))? #dayTimeIntervalDataType + | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType + ; + +qualifiedColTypeWithPositionList + : qualifiedColTypeWithPosition (',' qualifiedColTypeWithPosition)* + ; + +qualifiedColTypeWithPosition + : name=multipartIdentifier dataType (NOT NULL)? commentSpec? colPosition? + ; + +colTypeList + : colType (',' colType)* + ; + +colType + : colName=errorCapturingIdentifier dataType (NOT NULL)? commentSpec? + ; + +complexColTypeList + : complexColType (',' complexColType)* + ; + +complexColType + : identifier ':'? dataType (NOT NULL)? commentSpec? + ; + +whenClause + : WHEN condition=expression THEN result=expression + ; + +windowClause + : WINDOW namedWindow (',' namedWindow)* + ; + +namedWindow + : name=errorCapturingIdentifier AS windowSpec + ; + +windowSpec + : name=errorCapturingIdentifier #windowRef + | '('name=errorCapturingIdentifier')' #windowRef + | '(' + ( CLUSTER BY partition+=expression (',' partition+=expression)* + | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? + ((ORDER | SORT) BY sortItem (',' sortItem)*)?) + windowFrame? + ')' #windowDef + ; + +windowFrame + : frameType=RANGE start=frameBound + | frameType=ROWS start=frameBound + | frameType=RANGE BETWEEN start=frameBound AND end=frameBound + | frameType=ROWS BETWEEN start=frameBound AND end=frameBound + ; + +frameBound + : UNBOUNDED boundType=(PRECEDING | FOLLOWING) + | boundType=CURRENT ROW + | expression boundType=(PRECEDING | FOLLOWING) + ; + +qualifiedNameList + : qualifiedName (',' qualifiedName)* + ; + +functionName + : qualifiedName + | FILTER + | LEFT + | RIGHT + ; + +qualifiedName + : identifier ('.' identifier)* + ; + +// this rule is used for explicitly capturing wrong identifiers such as test-table, which should actually be `test-table` +// replace identifier with errorCapturingIdentifier where the immediate follow symbol is not an expression, otherwise +// valid expressions such as "a-b" can be recognized as an identifier +errorCapturingIdentifier + : identifier errorCapturingIdentifierExtra + ; + +// extra left-factoring grammar +errorCapturingIdentifierExtra + : (MINUS identifier)+ #errorIdent + | #realIdent + ; + +identifier + : strictIdentifier + | {!SQL_standard_keyword_behavior}? strictNonReserved + ; + +strictIdentifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | {SQL_standard_keyword_behavior}? ansiNonReserved #unquotedIdentifier + | {!SQL_standard_keyword_behavior}? nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +number + : {!legacy_exponent_literal_as_decimal_enabled}? MINUS? 
EXPONENT_VALUE #exponentLiteral + | {!legacy_exponent_literal_as_decimal_enabled}? MINUS? DECIMAL_VALUE #decimalLiteral + | {legacy_exponent_literal_as_decimal_enabled}? MINUS? (EXPONENT_VALUE | DECIMAL_VALUE) #legacyDecimalLiteral + | MINUS? INTEGER_VALUE #integerLiteral + | MINUS? BIGINT_LITERAL #bigIntLiteral + | MINUS? SMALLINT_LITERAL #smallIntLiteral + | MINUS? TINYINT_LITERAL #tinyIntLiteral + | MINUS? DOUBLE_LITERAL #doubleLiteral + | MINUS? FLOAT_LITERAL #floatLiteral + | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral + ; + +alterColumnAction + : TYPE dataType + | commentSpec + | colPosition + | setOrDrop=(SET | DROP) NOT NULL + ; + +// When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. +// - Reserved keywords: +// Keywords that are reserved and can't be used as identifiers for table, view, column, +// function, alias, etc. +// - Non-reserved keywords: +// Keywords that have a special meaning only in particular contexts and can be used as +// identifiers in other contexts. For example, `EXPLAIN SELECT ...` is a command, but EXPLAIN +// can be used as identifiers in other places. +// You can find the full keywords list by searching "Start of the keywords list" in this file. +// The non-reserved keywords are listed below. Keywords not in this list are reserved keywords. +ansiNonReserved +//--ANSI-NON-RESERVED-START + : ADD + | AFTER + | ALTER + | ANALYZE + | ANTI + | ARCHIVE + | ARRAY + | ASC + | AT + | BETWEEN + | BUCKET + | BUCKETS + | BY + | CACHE + | CASCADE + | CHANGE + | CLEAR + | CLUSTER + | CLUSTERED + | CODEGEN + | COLLECTION + | COLUMNS + | COMMENT + | COMMIT + | COMPACT + | COMPACTIONS + | COMPUTE + | CONCATENATE + | COST + | CUBE + | CURRENT + | DATA + | DATABASE + | DATABASES + | DAY + | DBPROPERTIES + | DEFINED + | DELETE + | DELIMITED + | DESC + | DESCRIBE + | DFS + | DIRECTORIES + | DIRECTORY + | DISTRIBUTE + | DIV + | DROP + | ESCAPED + | EXCHANGE + | EXISTS + | EXPLAIN + | EXPORT + | EXTENDED + | EXTERNAL + | EXTRACT + | FIELDS + | FILEFORMAT + | FIRST + | FOLLOWING + | FORMAT + | FORMATTED + | FUNCTION + | FUNCTIONS + | GLOBAL + | GROUPING + | HOUR + | IF + | IGNORE + | IMPORT + | INDEX + | INDEXES + | INPATH + | INPUTFORMAT + | INSERT + | INTERVAL + | ITEMS + | KEYS + | LAST + | LAZY + | LIKE + | LIMIT + | LINES + | LIST + | LOAD + | LOCAL + | LOCATION + | LOCK + | LOCKS + | LOGICAL + | MACRO + | MAP + | MATCHED + | MERGE + | MINUTE + | MONTH + | MSCK + | NAMESPACE + | NAMESPACES + | NO + | NULLS + | OF + | OPTION + | OPTIONS + | OUT + | OUTPUTFORMAT + | OVER + | OVERLAY + | OVERWRITE + | PARTITION + | PARTITIONED + | PARTITIONS + | PERCENTLIT + | PIVOT + | PLACING + | POSITION + | PRECEDING + | PRINCIPALS + | PROPERTIES + | PURGE + | QUERY + | RANGE + | RECORDREADER + | RECORDWRITER + | RECOVER + | REDUCE + | REFRESH + | RENAME + | REPAIR + | REPLACE + | RESET + | RESPECT + | RESTRICT + | REVOKE + | RLIKE + | ROLE + | ROLES + | ROLLBACK + | ROLLUP + | ROW + | ROWS + | SCHEMA + | SECOND + | SEMI + | SEPARATED + | SERDE + | SERDEPROPERTIES + | SET + | SETMINUS + | SETS + | SHOW + | SKEWED + | SORT + | SORTED + | START + | STATISTICS + | STORED + | STRATIFY + | STRUCT + | SUBSTR + | SUBSTRING + | SYNC + | TABLES + | TABLESAMPLE + | TBLPROPERTIES + | TEMPORARY + | TERMINATED + | TOUCH + | TRANSACTION + | TRANSACTIONS + | TRANSFORM + | TRIM + | TRUE + | TRUNCATE + | TRY_CAST + | TYPE + | UNARCHIVE + | UNBOUNDED + | UNCACHE + | UNLOCK + | UNSET + | UPDATE + | USE + | VALUES + | VIEW + | VIEWS + | WINDOW + | YEAR + | ZONE 
+//--ANSI-NON-RESERVED-END + ; + +// When `SQL_standard_keyword_behavior=false`, there are 2 kinds of keywords in Spark SQL. +// - Non-reserved keywords: +// Same definition as the one when `SQL_standard_keyword_behavior=true`. +// - Strict-non-reserved keywords: +// A strict version of non-reserved keywords, which can not be used as table alias. +// You can find the full keywords list by searching "Start of the keywords list" in this file. +// The strict-non-reserved keywords are listed in `strictNonReserved`. +// The non-reserved keywords are listed in `nonReserved`. +// These 2 together contain all the keywords. +strictNonReserved + : ANTI + | CROSS + | EXCEPT + | FULL + | INNER + | INTERSECT + | JOIN + | LATERAL + | LEFT + | NATURAL + | ON + | RIGHT + | SEMI + | SETMINUS + | UNION + | USING + ; + +nonReserved +//--DEFAULT-NON-RESERVED-START + : ADD + | AFTER + | ALL + | ALTER + | ANALYZE + | AND + | ANY + | ARCHIVE + | ARRAY + | AS + | ASC + | AT + | AUTHORIZATION + | BETWEEN + | BOTH + | BUCKET + | BUCKETS + | BY + | CACHE + | CASCADE + | CASE + | CAST + | CHANGE + | CHECK + | CLEAR + | CLUSTER + | CLUSTERED + | CODEGEN + | COLLATE + | COLLECTION + | COLUMN + | COLUMNS + | COMMENT + | COMMIT + | COMPACT + | COMPACTIONS + | COMPUTE + | CONCATENATE + | CONSTRAINT + | COST + | CREATE + | CUBE + | CURRENT + | CURRENT_DATE + | CURRENT_TIME + | CURRENT_TIMESTAMP + | CURRENT_USER + | DATA + | DATABASE + | DATABASES + | DAY + | DBPROPERTIES + | DEFINED + | DELETE + | DELIMITED + | DESC + | DESCRIBE + | DFS + | DIRECTORIES + | DIRECTORY + | DISTINCT + | DISTRIBUTE + | DIV + | DROP + | ELSE + | END + | ESCAPE + | ESCAPED + | EXCHANGE + | EXISTS + | EXPLAIN + | EXPORT + | EXTENDED + | EXTERNAL + | EXTRACT + | FALSE + | FETCH + | FILTER + | FIELDS + | FILEFORMAT + | FIRST + | FOLLOWING + | FOR + | FOREIGN + | FORMAT + | FORMATTED + | FROM + | FUNCTION + | FUNCTIONS + | GLOBAL + | GRANT + | GROUP + | GROUPING + | HAVING + | HOUR + | IF + | IGNORE + | IMPORT + | IN + | INDEX + | INDEXES + | INPATH + | INPUTFORMAT + | INSERT + | INTERVAL + | INTO + | IS + | ITEMS + | KEYS + | LAST + | LAZY + | LEADING + | LIKE + | LIMIT + | LINES + | LIST + | LOAD + | LOCAL + | LOCATION + | LOCK + | LOCKS + | LOGICAL + | MACRO + | MAP + | MATCHED + | MERGE + | MINUTE + | MONTH + | MSCK + | NAMESPACE + | NAMESPACES + | NO + | NOT + | NULL + | NULLS + | OF + | ONLY + | OPTION + | OPTIONS + | OR + | ORDER + | OUT + | OUTER + | OUTPUTFORMAT + | OVER + | OVERLAPS + | OVERLAY + | OVERWRITE + | PARTITION + | PARTITIONED + | PARTITIONS + | PERCENTLIT + | PIVOT + | PLACING + | POSITION + | PRECEDING + | PRIMARY + | PRINCIPALS + | PROPERTIES + | PURGE + | QUERY + | RANGE + | RECORDREADER + | RECORDWRITER + | RECOVER + | REDUCE + | REFERENCES + | REFRESH + | RENAME + | REPAIR + | REPLACE + | RESET + | RESPECT + | RESTRICT + | REVOKE + | RLIKE + | ROLE + | ROLES + | ROLLBACK + | ROLLUP + | ROW + | ROWS + | SCHEMA + | SECOND + | SELECT + | SEPARATED + | SERDE + | SERDEPROPERTIES + | SESSION_USER + | SET + | SETS + | SHOW + | SKEWED + | SOME + | SORT + | SORTED + | START + | STATISTICS + | STORED + | STRATIFY + | STRUCT + | SUBSTR + | SUBSTRING + | SYNC + | TABLE + | TABLES + | TABLESAMPLE + | TBLPROPERTIES + | TEMPORARY + | TERMINATED + | THEN + | TIME + | TO + | TOUCH + | TRAILING + | TRANSACTION + | TRANSACTIONS + | TRANSFORM + | TRIM + | TRUE + | TRUNCATE + | TRY_CAST + | TYPE + | UNARCHIVE + | UNBOUNDED + | UNCACHE + | UNIQUE + | UNKNOWN + | UNLOCK + | UNSET + | UPDATE + | USE + | USER + | VALUES + | VIEW + | VIEWS + | WHEN 
+ | WHERE + | WINDOW + | WITH + | YEAR + | ZONE + | SYSTEM_VERSION + | VERSION + | SYSTEM_TIME + | TIMESTAMP +//--DEFAULT-NON-RESERVED-END + ; + +// NOTE: If you add a new token in the list below, you should update the list of keywords +// and reserved tag in `docs/sql-ref-ansi-compliance.md#sql-keywords`. + +//============================ +// Start of the keywords list +//============================ +//--SPARK-KEYWORD-LIST-START +ADD: 'ADD'; +AFTER: 'AFTER'; +ALL: 'ALL'; +ALTER: 'ALTER'; +ANALYZE: 'ANALYZE'; +AND: 'AND'; +ANTI: 'ANTI'; +ANY: 'ANY'; +ARCHIVE: 'ARCHIVE'; +ARRAY: 'ARRAY'; +AS: 'AS'; +ASC: 'ASC'; +AT: 'AT'; +AUTHORIZATION: 'AUTHORIZATION'; +BETWEEN: 'BETWEEN'; +BOTH: 'BOTH'; +BUCKET: 'BUCKET'; +BUCKETS: 'BUCKETS'; +BY: 'BY'; +CACHE: 'CACHE'; +CASCADE: 'CASCADE'; +CASE: 'CASE'; +CAST: 'CAST'; +CHANGE: 'CHANGE'; +CHECK: 'CHECK'; +CLEAR: 'CLEAR'; +CLUSTER: 'CLUSTER'; +CLUSTERED: 'CLUSTERED'; +CODEGEN: 'CODEGEN'; +COLLATE: 'COLLATE'; +COLLECTION: 'COLLECTION'; +COLUMN: 'COLUMN'; +COLUMNS: 'COLUMNS'; +COMMENT: 'COMMENT'; +COMMIT: 'COMMIT'; +COMPACT: 'COMPACT'; +COMPACTIONS: 'COMPACTIONS'; +COMPUTE: 'COMPUTE'; +CONCATENATE: 'CONCATENATE'; +CONSTRAINT: 'CONSTRAINT'; +COST: 'COST'; +CREATE: 'CREATE'; +CROSS: 'CROSS'; +CUBE: 'CUBE'; +CURRENT: 'CURRENT'; +CURRENT_DATE: 'CURRENT_DATE'; +CURRENT_TIME: 'CURRENT_TIME'; +CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +CURRENT_USER: 'CURRENT_USER'; +DAY: 'DAY'; +DATA: 'DATA'; +DATABASE: 'DATABASE'; +DATABASES: 'DATABASES' | 'SCHEMAS'; +DBPROPERTIES: 'DBPROPERTIES'; +DEFINED: 'DEFINED'; +DELETE: 'DELETE'; +DELIMITED: 'DELIMITED'; +DESC: 'DESC'; +DESCRIBE: 'DESCRIBE'; +DFS: 'DFS'; +DIRECTORIES: 'DIRECTORIES'; +DIRECTORY: 'DIRECTORY'; +DISTINCT: 'DISTINCT'; +DISTRIBUTE: 'DISTRIBUTE'; +DIV: 'DIV'; +DROP: 'DROP'; +ELSE: 'ELSE'; +END: 'END'; +ESCAPE: 'ESCAPE'; +ESCAPED: 'ESCAPED'; +EXCEPT: 'EXCEPT'; +EXCHANGE: 'EXCHANGE'; +EXISTS: 'EXISTS'; +EXPLAIN: 'EXPLAIN'; +EXPORT: 'EXPORT'; +EXTENDED: 'EXTENDED'; +EXTERNAL: 'EXTERNAL'; +EXTRACT: 'EXTRACT'; +FALSE: 'FALSE'; +FETCH: 'FETCH'; +FIELDS: 'FIELDS'; +FILTER: 'FILTER'; +FILEFORMAT: 'FILEFORMAT'; +FIRST: 'FIRST'; +FOLLOWING: 'FOLLOWING'; +FOR: 'FOR'; +FOREIGN: 'FOREIGN'; +FORMAT: 'FORMAT'; +FORMATTED: 'FORMATTED'; +FROM: 'FROM'; +FULL: 'FULL'; +FUNCTION: 'FUNCTION'; +FUNCTIONS: 'FUNCTIONS'; +GLOBAL: 'GLOBAL'; +GRANT: 'GRANT'; +GROUP: 'GROUP'; +GROUPING: 'GROUPING'; +HAVING: 'HAVING'; +HOUR: 'HOUR'; +IF: 'IF'; +IGNORE: 'IGNORE'; +IMPORT: 'IMPORT'; +IN: 'IN'; +INDEX: 'INDEX'; +INDEXES: 'INDEXES'; +INNER: 'INNER'; +INPATH: 'INPATH'; +INPUTFORMAT: 'INPUTFORMAT'; +INSERT: 'INSERT'; +INTERSECT: 'INTERSECT'; +INTERVAL: 'INTERVAL'; +INTO: 'INTO'; +IS: 'IS'; +ITEMS: 'ITEMS'; +JOIN: 'JOIN'; +KEYS: 'KEYS'; +LAST: 'LAST'; +LATERAL: 'LATERAL'; +LAZY: 'LAZY'; +LEADING: 'LEADING'; +LEFT: 'LEFT'; +LIKE: 'LIKE'; +LIMIT: 'LIMIT'; +LINES: 'LINES'; +LIST: 'LIST'; +LOAD: 'LOAD'; +LOCAL: 'LOCAL'; +LOCATION: 'LOCATION'; +LOCK: 'LOCK'; +LOCKS: 'LOCKS'; +LOGICAL: 'LOGICAL'; +MACRO: 'MACRO'; +MAP: 'MAP'; +MATCHED: 'MATCHED'; +MERGE: 'MERGE'; +MINUTE: 'MINUTE'; +MONTH: 'MONTH'; +MSCK: 'MSCK'; +NAMESPACE: 'NAMESPACE'; +NAMESPACES: 'NAMESPACES'; +NATURAL: 'NATURAL'; +NO: 'NO'; +NOT: 'NOT' | '!'; +NULL: 'NULL'; +NULLS: 'NULLS'; +OF: 'OF'; +ON: 'ON'; +ONLY: 'ONLY'; +OPTION: 'OPTION'; +OPTIONS: 'OPTIONS'; +OR: 'OR'; +ORDER: 'ORDER'; +OUT: 'OUT'; +OUTER: 'OUTER'; +OUTPUTFORMAT: 'OUTPUTFORMAT'; +OVER: 'OVER'; +OVERLAPS: 'OVERLAPS'; +OVERLAY: 'OVERLAY'; +OVERWRITE: 'OVERWRITE'; +PARTITION: 'PARTITION'; +PARTITIONED: 'PARTITIONED'; 
+PARTITIONS: 'PARTITIONS'; +PERCENTLIT: 'PERCENT'; +PIVOT: 'PIVOT'; +PLACING: 'PLACING'; +POSITION: 'POSITION'; +PRECEDING: 'PRECEDING'; +PRIMARY: 'PRIMARY'; +PRINCIPALS: 'PRINCIPALS'; +PROPERTIES: 'PROPERTIES'; +PURGE: 'PURGE'; +QUERY: 'QUERY'; +RANGE: 'RANGE'; +RECORDREADER: 'RECORDREADER'; +RECORDWRITER: 'RECORDWRITER'; +RECOVER: 'RECOVER'; +REDUCE: 'REDUCE'; +REFERENCES: 'REFERENCES'; +REFRESH: 'REFRESH'; +RENAME: 'RENAME'; +REPAIR: 'REPAIR'; +REPLACE: 'REPLACE'; +RESET: 'RESET'; +RESPECT: 'RESPECT'; +RESTRICT: 'RESTRICT'; +REVOKE: 'REVOKE'; +RIGHT: 'RIGHT'; +RLIKE: 'RLIKE' | 'REGEXP'; +ROLE: 'ROLE'; +ROLES: 'ROLES'; +ROLLBACK: 'ROLLBACK'; +ROLLUP: 'ROLLUP'; +ROW: 'ROW'; +ROWS: 'ROWS'; +SECOND: 'SECOND'; +SCHEMA: 'SCHEMA'; +SELECT: 'SELECT'; +SEMI: 'SEMI'; +SEPARATED: 'SEPARATED'; +SERDE: 'SERDE'; +SERDEPROPERTIES: 'SERDEPROPERTIES'; +SESSION_USER: 'SESSION_USER'; +SET: 'SET'; +SETMINUS: 'MINUS'; +SETS: 'SETS'; +SHOW: 'SHOW'; +SKEWED: 'SKEWED'; +SOME: 'SOME'; +SORT: 'SORT'; +SORTED: 'SORTED'; +START: 'START'; +STATISTICS: 'STATISTICS'; +STORED: 'STORED'; +STRATIFY: 'STRATIFY'; +STRUCT: 'STRUCT'; +SUBSTR: 'SUBSTR'; +SUBSTRING: 'SUBSTRING'; +SYNC: 'SYNC'; +TABLE: 'TABLE'; +TABLES: 'TABLES'; +TABLESAMPLE: 'TABLESAMPLE'; +TBLPROPERTIES: 'TBLPROPERTIES'; +TEMPORARY: 'TEMPORARY' | 'TEMP'; +TERMINATED: 'TERMINATED'; +THEN: 'THEN'; +TIME: 'TIME'; +TO: 'TO'; +TOUCH: 'TOUCH'; +TRAILING: 'TRAILING'; +TRANSACTION: 'TRANSACTION'; +TRANSACTIONS: 'TRANSACTIONS'; +TRANSFORM: 'TRANSFORM'; +TRIM: 'TRIM'; +TRUE: 'TRUE'; +TRUNCATE: 'TRUNCATE'; +TRY_CAST: 'TRY_CAST'; +TYPE: 'TYPE'; +UNARCHIVE: 'UNARCHIVE'; +UNBOUNDED: 'UNBOUNDED'; +UNCACHE: 'UNCACHE'; +UNION: 'UNION'; +UNIQUE: 'UNIQUE'; +UNKNOWN: 'UNKNOWN'; +UNLOCK: 'UNLOCK'; +UNSET: 'UNSET'; +UPDATE: 'UPDATE'; +USE: 'USE'; +USER: 'USER'; +USING: 'USING'; +VALUES: 'VALUES'; +VIEW: 'VIEW'; +VIEWS: 'VIEWS'; +WHEN: 'WHEN'; +WHERE: 'WHERE'; +WINDOW: 'WINDOW'; +WITH: 'WITH'; +YEAR: 'YEAR'; +ZONE: 'ZONE'; + +SYSTEM_VERSION: 'SYSTEM_VERSION'; +VERSION: 'VERSION'; +SYSTEM_TIME: 'SYSTEM_TIME'; +TIMESTAMP: 'TIMESTAMP'; +//--SPARK-KEYWORD-LIST-END +//============================ +// End of the keywords list +//============================ + +EQ : '=' | '=='; +NSEQ: '<=>'; +NEQ : '<>'; +NEQJ: '!='; +LT : '<'; +LTE : '<=' | '!>'; +GT : '>'; +GTE : '>=' | '!<'; + +PLUS: '+'; +MINUS: '-'; +ASTERISK: '*'; +SLASH: '/'; +PERCENT: '%'; +TILDE: '~'; +AMPERSAND: '&'; +PIPE: '|'; +CONCAT_PIPE: '||'; +HAT: '^'; + +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '"' ( ~('"'|'\\') | ('\\' .) )* '"' + ; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +EXPONENT_VALUE + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? + ; + +DECIMAL_VALUE + : DECIMAL_DIGITS {isValidDecimal()}? + ; + +FLOAT_LITERAL + : DIGIT+ EXPONENT? 'F' + | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}? + ; + +DOUBLE_LITERAL + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + +BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? 
-> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' {!isHint()}? (BRACKETED_COMMENT|.)*? '*/' -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . + ; diff --git a/hudi-spark-datasource/hudi-spark3/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 b/hudi-spark-datasource/hudi-spark3/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 new file mode 100644 index 0000000000000..0ee6bee1970c1 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3/src/main/antlr4/org/apache/hudi/spark/sql/parser/HoodieSqlBase.g4 @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +grammar HoodieSqlBase; + +import SqlBase; + +singleStatement + : statement EOF + ; + +statement + : query #queryStatement + | ctes? dmlStatementNoWith #dmlStatement + | createTableHeader ('(' colTypeList ')')? tableProvider? + createTableClauses + (AS? query)? #createTable + | replaceTableHeader ('(' colTypeList ')')? tableProvider? + createTableClauses + (AS? query)? #replaceTable + | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? + VIEW (IF NOT EXISTS)? multipartIdentifier + identifierCommentList? + (commentSpec | + (PARTITIONED ON identifierList) | + (TBLPROPERTIES tablePropertyList))* + AS query #createView + | ALTER VIEW multipartIdentifier AS? query #alterViewQuery + | (DESC | DESCRIBE) QUERY? query #describeQuery + | CACHE LAZY? TABLE multipartIdentifier + (OPTIONS options=tablePropertyList)? (AS? query)? #cacheTable + | .*? #passThrough + ; diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/AnalysisException.scala new file mode 100644 index 0000000000000..1482776fd16e8 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.Origin + +/** + * Thrown when a query fails to analyze, usually because the query itself is invalid. + * + * @since 1.3.0 + */ +@Stable +class AnalysisException protected[sql]( + val message: String, + val line: Option[Int] = None, + val startPosition: Option[Int] = None, + // Some plans fail to serialize due to bugs in scala collections. + @transient val plan: Option[LogicalPlan] = None, + val cause: Option[Throwable] = None, + val errorClass: Option[String] = None, + val messageParameters: Array[String] = Array.empty) + extends Exception(message, cause.orNull) with Serializable { + + def this(errorClass: String, messageParameters: Array[String], cause: Option[Throwable]) = + this( + SparkThrowableHelper.getMessage(errorClass, messageParameters), + errorClass = Some(errorClass), + messageParameters = messageParameters, + cause = cause) + + def this(errorClass: String, messageParameters: Array[String]) = + this(errorClass = errorClass, messageParameters = messageParameters, cause = None) + + def this( + errorClass: String, + messageParameters: Array[String], + origin: Origin) = + this( + SparkThrowableHelper.getMessage(errorClass, messageParameters), + line = origin.line, + startPosition = origin.startPosition, + errorClass = Some(errorClass), + messageParameters = messageParameters) + + def copy( + message: String = this.message, + line: Option[Int] = this.line, + startPosition: Option[Int] = this.startPosition, + plan: Option[LogicalPlan] = this.plan, + cause: Option[Throwable] = this.cause, + errorClass: Option[String] = this.errorClass, + messageParameters: Array[String] = this.messageParameters): AnalysisException = + new AnalysisException(message, line, startPosition, plan, cause, errorClass, messageParameters) + + def withPosition(line: Option[Int], startPosition: Option[Int]): AnalysisException = { + val newException = this.copy(line = line, startPosition = startPosition) + newException.setStackTrace(getStackTrace) + newException + } + + override def getMessage: String = { + val planAnnotation = Option(plan).flatten.map(p => s";\n$p").getOrElse("") + getSimpleMessage + planAnnotation + } + + // Outputs an exception without the logical plan. + // For testing only + def getSimpleMessage: String = if (line.isDefined || startPosition.isDefined) { + val lineAnnotation = line.map(l => s" line $l").getOrElse("") + val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("") + s"$message;$lineAnnotation$positionAnnotation" + } else { + message + } + + def getErrorClass: String = errorClass.orNull + def getSqlState: String = SparkThrowableHelper.getSqlState(errorClass.orNull) +} diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/ErrorInfo.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/ErrorInfo.scala new file mode 100644 index 0000000000000..6ca2af6669cd5 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/ErrorInfo.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.net.URL + +import scala.collection.immutable.SortedMap + +import com.fasterxml.jackson.annotation.JsonIgnore +import com.fasterxml.jackson.core.`type`.TypeReference +import com.fasterxml.jackson.databind.json.JsonMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule + +import org.apache.spark.util.Utils + +/** + * Information associated with an error class. + * + * @param sqlState SQLSTATE associated with this class. + * @param message C-style message format compatible with printf. + * The error message is constructed by concatenating the lines with newlines. + */ +private[spark] case class ErrorInfo(message: Seq[String], sqlState: Option[String]) { + // For compatibility with multi-line error messages + @JsonIgnore + val messageFormat: String = message.mkString("\n") +} + +/** + * Companion object used by instances of [[SparkThrowable]] to access error class information and + * construct error messages. + */ +private[spark] object SparkThrowableHelper { + val errorClassesUrl: URL = + Utils.getSparkClassLoader.getResource("error/error-classes.json") + val errorClassToInfoMap: SortedMap[String, ErrorInfo] = { + val mapper: JsonMapper = JsonMapper.builder() + .addModule(DefaultScalaModule) + .build() + mapper.readValue(errorClassesUrl, new TypeReference[SortedMap[String, ErrorInfo]]() {}) + } + + def getMessage(errorClass: String, messageParameters: Array[String]): String = { + val errorInfo = errorClassToInfoMap.getOrElse(errorClass, + throw new IllegalArgumentException(s"Cannot find error class '$errorClass'")) + String.format(errorInfo.messageFormat, messageParameters: _*) + } + + def getSqlState(errorClass: String): String = { + Option(errorClass).flatMap(errorClassToInfoMap.get).flatMap(_.sqlState).orNull + } +} diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala new file mode 100644 index 0000000000000..0022f0c7b1efd --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/adapter/Spark3_2Adapter.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.adapter + +import org.apache.hudi.Spark3RowSerDe +import org.apache.hudi.client.utils.SparkRowSerDe +import org.apache.hudi.spark3.internal.ReflectUtil +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Expression, Like} +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, Join, JoinHint, LogicalPlan} +import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ +import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile, Spark3ParsePartitionUtil, SparkParsePartitionUtil} +import org.apache.spark.sql.hudi.SparkAdapter +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.parser.HoodieSpark3_2ExtendedSqlParser +import org.apache.spark.sql.{Row, SparkSession} + +/** + * The adapter for spark3. + */ +class Spark3_2Adapter extends SparkAdapter { + + override def createSparkRowSerDe(encoder: ExpressionEncoder[Row]): SparkRowSerDe = { + new Spark3RowSerDe(encoder) + } + + override def toTableIdentifier(aliasId: AliasIdentifier): TableIdentifier = { + aliasId match { + case AliasIdentifier(name, Seq(database)) => + TableIdentifier(name, Some(database)) + case AliasIdentifier(name, Seq(_, database)) => + TableIdentifier(name, Some(database)) + case AliasIdentifier(name, Seq()) => + TableIdentifier(name, None) + case _=> throw new IllegalArgumentException(s"Cannot cast $aliasId to TableIdentifier") + } + } + + override def toTableIdentifier(relation: UnresolvedRelation): TableIdentifier = { + relation.multipartIdentifier.asTableIdentifier + } + + override def createJoin(left: LogicalPlan, right: LogicalPlan, joinType: JoinType): Join = { + Join(left, right, joinType, None, JoinHint.NONE) + } + + override def isInsertInto(plan: LogicalPlan): Boolean = { + plan.isInstanceOf[InsertIntoStatement] + } + + override def getInsertIntoChildren(plan: LogicalPlan): + Option[(LogicalPlan, Map[String, Option[String]], LogicalPlan, Boolean, Boolean)] = { + plan match { + case insert: InsertIntoStatement => + Some((insert.table, insert.partitionSpec, insert.query, insert.overwrite, insert.ifPartitionNotExists)) + case _ => + None + } + } + + override def createInsertInto(table: LogicalPlan, partition: Map[String, Option[String]], + query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean): LogicalPlan = { + ReflectUtil.createInsertInto(table, partition, Seq.empty[String], query, overwrite, ifPartitionNotExists) + } + + override def createSparkParsePartitionUtil(conf: SQLConf): SparkParsePartitionUtil = { + new Spark3ParsePartitionUtil(conf) + } + + override def createLike(left: Expression, right: Expression): Expression = { + new Like(left, right) + } + + override def parseMultipartIdentifier(parser: ParserInterface, sqlText: String): Seq[String] = { + parser.parseMultipartIdentifier(sqlText) + } + + override def createExtendedSparkParser: Option[(SparkSession, ParserInterface) => ParserInterface] = { + Some( + (spark: SparkSession, delegate: ParserInterface) => new HoodieSpark3_2ExtendedSqlParser(spark, delegate) + ) + } + + /** + * Combine [[PartitionedFile]] to 
[[FilePartition]] according to `maxSplitBytes`. + */ + override def getFilePartitions( + sparkSession: SparkSession, + partitionedFiles: Seq[PartitionedFile], + maxSplitBytes: Long): Seq[FilePartition] = { + FilePartition.getFilePartitions(sparkSession, partitionedFiles, maxSplitBytes) + } +} diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlAstBuilder.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlAstBuilder.scala new file mode 100644 index 0000000000000..0524b073a1a91 --- /dev/null +++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlAstBuilder.scala @@ -0,0 +1,3871 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.parser + +import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} +import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser._ +import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseVisitor, HoodieSqlBaseParser} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} +import org.apache.spark.sql.catalyst.parser.ParserUtils.{EnhancedLogicalPlan, checkDuplicateClauses, checkDuplicateKeys, entry, escapedIdentifier, operationNotAllowed, source, string, stringWithoutUnescape, validate, withOrigin} +import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, IntervalUtils, truncatedString} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition +import org.apache.spark.sql.connector.catalog.{SupportsNamespaces, TableCatalog} +import org.apache.spark.sql.connector.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform, Expression => V2Expression} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.util.Utils.isTesting +import 
org.apache.spark.util.random.RandomSampler + +import java.util.Locale +import java.util.concurrent.TimeUnit +import javax.xml.bind.DatatypeConverter +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +/** + * The AstBuilder for HoodieSqlParser to parser the AST tree to Logical Plan. + * Here we only do the parser for the extended sql syntax. e.g MergeInto. For + * other sql syntax we use the delegate sql parser which is the SparkSqlParser. + */ +class HoodieSpark3_2ExtendedSqlAstBuilder(conf: SQLConf, delegate: ParserInterface) + extends HoodieSqlBaseBaseVisitor[AnyRef] with Logging { + + protected def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } + + /** + * Override the default behavior for all visit methods. This will only return a non-null result + * when the context has only one child. This is done because there is no generic method to + * combine the results of the context children. In all other cases null is returned. + */ + override def visitChildren(node: RuleNode): AnyRef = { + if (node.getChildCount == 1) { + node.getChild(0).accept(this) + } else { + null + } + } + + /** + * Create an aliased table reference. This is typically used in FROM clauses. + */ + override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { + val tableId = visitMultipartIdentifier(ctx.multipartIdentifier()) + val relation = UnresolvedRelation(tableId) + val table = mayApplyAliasPlan( + ctx.tableAlias, relation.optionalMap(ctx.temporalClause)(withTimeTravel)) + table.optionalMap(ctx.sample)(withSample) + } + + private def withTimeTravel( + ctx: TemporalClauseContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val v = ctx.version + val version = if (ctx.INTEGER_VALUE != null) { + Some(v.getText) + } else { + Option(v).map(string) + } + + val timestamp = Option(ctx.timestamp).map(expression) + if (timestamp.exists(_.references.nonEmpty)) { + throw new ParseException( + "timestamp expression cannot refer to any columns", ctx.timestamp) + } + if (timestamp.exists(e => SubqueryExpression.hasSubquery(e))) { + throw new ParseException( + "timestamp expression cannot contain subqueries", ctx.timestamp) + } + + TimeTravelRelation(plan, timestamp, version) + } + + // ============== The following code is fork from org.apache.spark.sql.catalyst.parser.AstBuilder + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { + visit(ctx.statement).asInstanceOf[LogicalPlan] + } + + override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) { + visitNamedExpression(ctx.namedExpression) + } + + override def visitSingleTableIdentifier( + ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) { + visitTableIdentifier(ctx.tableIdentifier) + } + + override def visitSingleFunctionIdentifier( + ctx: SingleFunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + visitFunctionIdentifier(ctx.functionIdentifier) + } + + override def visitSingleMultipartIdentifier( + ctx: SingleMultipartIdentifierContext): Seq[String] = withOrigin(ctx) { + visitMultipartIdentifier(ctx.multipartIdentifier) + } + + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { + typedVisit[DataType](ctx.dataType) + } + + override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = { + val schema = StructType(visitColTypeList(ctx.colTypeList)) + withOrigin(ctx)(schema) + } + + /* 
******************************************************************************************** + * Plan parsing + * ******************************************************************************************** */ + protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree) + + /** + * Create a top-level plan with Common Table Expressions. + */ + override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) { + val query = plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) + + // Apply CTEs + query.optionalMap(ctx.ctes)(withCTE) + } + + override def visitDmlStatement(ctx: DmlStatementContext): AnyRef = withOrigin(ctx) { + val dmlStmt = plan(ctx.dmlStatementNoWith) + // Apply CTEs + dmlStmt.optionalMap(ctx.ctes)(withCTE) + } + + private def withCTE(ctx: CtesContext, plan: LogicalPlan): LogicalPlan = { + val ctes = ctx.namedQuery.asScala.map { nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) + } + // Check for duplicate names. + val duplicates = ctes.groupBy(_._1).filter(_._2.size > 1).keys + if (duplicates.nonEmpty) { + throw duplicateCteDefinitionNamesError( + duplicates.mkString("'", "', '", "'"), ctx) + } + UnresolvedWith(plan, ctes.toSeq) + } + + /** + * Create a logical query plan for a hive-style FROM statement body. + */ + private def withFromStatementBody( + ctx: FromStatementBodyContext, plan: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // two cases for transforms and selects + if (ctx.transformClause != null) { + withTransformQuerySpecification( + ctx, + ctx.transformClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + plan + ) + } else { + withSelectQuerySpecification( + ctx, + ctx.selectClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + plan + ) + } + } + + override def visitFromStatement(ctx: FromStatementContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + val selects = ctx.fromStatementBody.asScala.map { body => + withFromStatementBody(body, from). + // Add organization statements. + optionalMap(body.queryOrganization)(withQueryResultClauses) + } + // If there are multiple SELECT just UNION them together into one query. + if (selects.length == 1) { + selects.head + } else { + Union(selects.toSeq) + } + } + + /** + * Create a named logical plan. + * + * This is only used for Common Table Expressions. + */ + override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { + val subQuery: LogicalPlan = plan(ctx.query).optionalMap(ctx.columnAliases)( + (columnAliases, plan) => + UnresolvedSubqueryColumnAliases(visitIdentifierList(columnAliases), plan) + ) + SubqueryAlias(ctx.name.getText, subQuery) + } + + /** + * Create a logical plan which allows for multiple inserts using one 'from' statement. These + * queries have the following SQL form: + * {{{ + * [WITH cte...]? + * FROM src + * [INSERT INTO tbl1 SELECT *]+ + * }}} + * For example: + * {{{ + * FROM db.tbl1 A + * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5 + * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12 + * }}} + * This (Hive) feature cannot be combined with set-operators. + */ + override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + + // Build the insert clauses. 
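+    // Each INSERT branch below reuses the single FROM plan parsed above; the branch's own
+    // query organization (ORDER BY, LIMIT, ...) is applied before the INSERT wrapper, and
+    // multiple branches are combined with a Union further down.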
+ val inserts = ctx.multiInsertQueryBody.asScala.map { body => + withInsertInto(body.insertInto, + withFromStatementBody(body.fromStatementBody, from). + optionalMap(body.fromStatementBody.queryOrganization)(withQueryResultClauses)) + } + + // If there are multiple INSERTS just UNION them together into one query. + if (inserts.length == 1) { + inserts.head + } else { + Union(inserts.toSeq) + } + } + + /** + * Create a logical plan for a regular (single-insert) query. + */ + override def visitSingleInsertQuery( + ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { + withInsertInto( + ctx.insertInto(), + plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses)) + } + + /** + * Parameters used for writing query to a table: + * (UnresolvedRelation, tableColumnList, partitionKeys, ifPartitionNotExists). + */ + type InsertTableParams = (UnresolvedRelation, Seq[String], Map[String, Option[String]], Boolean) + + /** + * Parameters used for writing query to a directory: (isLocal, CatalogStorageFormat, provider). + */ + type InsertDirParams = (Boolean, CatalogStorageFormat, Option[String]) + + /** + * Add an + * {{{ + * INSERT OVERWRITE TABLE tableIdentifier [partitionSpec [IF NOT EXISTS]]? [identifierList] + * INSERT INTO [TABLE] tableIdentifier [partitionSpec] [identifierList] + * INSERT OVERWRITE [LOCAL] DIRECTORY STRING [rowFormat] [createFileFormat] + * INSERT OVERWRITE [LOCAL] DIRECTORY [STRING] tableProvider [OPTIONS tablePropertyList] + * }}} + * operation to logical plan + */ + private def withInsertInto( + ctx: InsertIntoContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + ctx match { + case table: InsertIntoTableContext => + val (relation, cols, partition, ifPartitionNotExists) = visitInsertIntoTable(table) + InsertIntoStatement( + relation, + partition, + cols, + query, + overwrite = false, + ifPartitionNotExists) + case table: InsertOverwriteTableContext => + val (relation, cols, partition, ifPartitionNotExists) = visitInsertOverwriteTable(table) + InsertIntoStatement( + relation, + partition, + cols, + query, + overwrite = true, + ifPartitionNotExists) + case dir: InsertOverwriteDirContext => + val (isLocal, storage, provider) = visitInsertOverwriteDir(dir) + InsertIntoDir(isLocal, storage, provider, query, overwrite = true) + case hiveDir: InsertOverwriteHiveDirContext => + val (isLocal, storage, provider) = visitInsertOverwriteHiveDir(hiveDir) + InsertIntoDir(isLocal, storage, provider, query, overwrite = true) + case _ => + throw invalidInsertIntoError(ctx) + } + } + + /** + * Add an INSERT INTO TABLE operation to the logical plan. + */ + override def visitInsertIntoTable( + ctx: InsertIntoTableContext): InsertTableParams = withOrigin(ctx) { + val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + if (ctx.EXISTS != null) { + operationNotAllowed("INSERT INTO ... IF NOT EXISTS", ctx) + } + + (createUnresolvedRelation(ctx.multipartIdentifier), cols, partitionKeys, false) + } + + /** + * Add an INSERT OVERWRITE TABLE operation to the logical plan. 
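+   *
+   * For illustration only (hypothetical table and column names), a statement such as
+   * `INSERT OVERWRITE TABLE target PARTITION (dt = '2021-01-01') SELECT * FROM source`
+   * becomes an [[InsertIntoStatement]] with `overwrite = true`. `IF NOT EXISTS` is only
+   * accepted when every partition column is given a static value; otherwise the statement
+   * is rejected below as "IF NOT EXISTS with dynamic partitions".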
+ */ + override def visitInsertOverwriteTable( + ctx: InsertOverwriteTableContext): InsertTableParams = withOrigin(ctx) { + assert(ctx.OVERWRITE() != null) + val cols = Option(ctx.identifierList()).map(visitIdentifierList).getOrElse(Nil) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + val dynamicPartitionKeys: Map[String, Option[String]] = partitionKeys.filter(_._2.isEmpty) + if (ctx.EXISTS != null && dynamicPartitionKeys.nonEmpty) { + operationNotAllowed("IF NOT EXISTS with dynamic partitions: " + + dynamicPartitionKeys.keys.mkString(", "), ctx) + } + + (createUnresolvedRelation(ctx.multipartIdentifier), cols, partitionKeys, ctx.EXISTS() != null) + } + + /** + * Write to a directory, returning a [[InsertIntoDir]] logical plan. + */ + override def visitInsertOverwriteDir( + ctx: InsertOverwriteDirContext): InsertDirParams = withOrigin(ctx) { + throw insertOverwriteDirectoryUnsupportedError(ctx) + } + + /** + * Write to a directory, returning a [[InsertIntoDir]] logical plan. + */ + override def visitInsertOverwriteHiveDir( + ctx: InsertOverwriteHiveDirContext): InsertDirParams = withOrigin(ctx) { + throw insertOverwriteDirectoryUnsupportedError(ctx) + } + + private def getTableAliasWithoutColumnAlias( + ctx: TableAliasContext, op: String): Option[String] = { + if (ctx == null) { + None + } else { + val ident = ctx.strictIdentifier() + if (ctx.identifierList() != null) { + throw columnAliasInOperationNotAllowedError(op, ctx) + } + if (ident != null) Some(ident.getText) else None + } + } + + override def visitDeleteFromTable( + ctx: DeleteFromTableContext): LogicalPlan = withOrigin(ctx) { + val table = createUnresolvedRelation(ctx.multipartIdentifier()) + val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "DELETE") + val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table) + val predicate = if (ctx.whereClause() != null) { + Some(expression(ctx.whereClause().booleanExpression())) + } else { + None + } + DeleteFromTable(aliasedTable, predicate) + } + + override def visitUpdateTable(ctx: UpdateTableContext): LogicalPlan = withOrigin(ctx) { + val table = createUnresolvedRelation(ctx.multipartIdentifier()) + val tableAlias = getTableAliasWithoutColumnAlias(ctx.tableAlias(), "UPDATE") + val aliasedTable = tableAlias.map(SubqueryAlias(_, table)).getOrElse(table) + val assignments = withAssignments(ctx.setClause().assignmentList()) + val predicate = if (ctx.whereClause() != null) { + Some(expression(ctx.whereClause().booleanExpression())) + } else { + None + } + + UpdateTable(aliasedTable, assignments, predicate) + } + + private def withAssignments(assignCtx: AssignmentListContext): Seq[Assignment] = + withOrigin(assignCtx) { + assignCtx.assignment().asScala.map { assign => + Assignment(UnresolvedAttribute(visitMultipartIdentifier(assign.key)), + expression(assign.value)) + }.toSeq + } + + override def visitMergeIntoTable(ctx: MergeIntoTableContext): LogicalPlan = withOrigin(ctx) { + val targetTable = createUnresolvedRelation(ctx.target) + val targetTableAlias = getTableAliasWithoutColumnAlias(ctx.targetAlias, "MERGE") + val aliasedTarget = targetTableAlias.map(SubqueryAlias(_, targetTable)).getOrElse(targetTable) + + val sourceTableOrQuery = if (ctx.source != null) { + createUnresolvedRelation(ctx.source) + } else if (ctx.sourceQuery != null) { + visitQuery(ctx.sourceQuery) + } else { + throw emptySourceForMergeError(ctx) + } + val sourceTableAlias = getTableAliasWithoutColumnAlias(ctx.sourceAlias, "MERGE") + val 
aliasedSource = + sourceTableAlias.map(SubqueryAlias(_, sourceTableOrQuery)).getOrElse(sourceTableOrQuery) + + val mergeCondition = expression(ctx.mergeCondition) + + val matchedActions = ctx.matchedClause().asScala.map { + clause => { + if (clause.matchedAction().DELETE() != null) { + DeleteAction(Option(clause.matchedCond).map(expression)) + } else if (clause.matchedAction().UPDATE() != null) { + val condition = Option(clause.matchedCond).map(expression) + if (clause.matchedAction().ASTERISK() != null) { + UpdateStarAction(condition) + } else { + UpdateAction(condition, withAssignments(clause.matchedAction().assignmentList())) + } + } else { + // It should not be here. + throw unrecognizedMatchedActionError(clause) + } + } + } + val notMatchedActions = ctx.notMatchedClause().asScala.map { + clause => { + if (clause.notMatchedAction().INSERT() != null) { + val condition = Option(clause.notMatchedCond).map(expression) + if (clause.notMatchedAction().ASTERISK() != null) { + InsertStarAction(condition) + } else { + val columns = clause.notMatchedAction().columns.multipartIdentifier() + .asScala.map(attr => UnresolvedAttribute(visitMultipartIdentifier(attr))) + val values = clause.notMatchedAction().expression().asScala.map(expression) + if (columns.size != values.size) { + throw insertedValueNumberNotMatchFieldNumberError(clause) + } + InsertAction(condition, columns.zip(values).map(kv => Assignment(kv._1, kv._2)).toSeq) + } + } else { + // It should not be here. + throw unrecognizedNotMatchedActionError(clause) + } + } + } + if (matchedActions.isEmpty && notMatchedActions.isEmpty) { + throw mergeStatementWithoutWhenClauseError(ctx) + } + // children being empty means that the condition is not set + val matchedActionSize = matchedActions.length + if (matchedActionSize >= 2 && !matchedActions.init.forall(_.condition.nonEmpty)) { + throw nonLastMatchedClauseOmitConditionError(ctx) + } + val notMatchedActionSize = notMatchedActions.length + if (notMatchedActionSize >= 2 && !notMatchedActions.init.forall(_.condition.nonEmpty)) { + throw nonLastNotMatchedClauseOmitConditionError(ctx) + } + + MergeIntoTable( + aliasedTarget, + aliasedSource, + mergeCondition, + matchedActions.toSeq, + notMatchedActions.toSeq) + } + + /** + * Create a partition specification map. + */ + override def visitPartitionSpec( + ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { + val legacyNullAsString = + conf.getConf(SQLConf.LEGACY_PARSE_NULL_PARTITION_SPEC_AS_STRING_LITERAL) + val parts = ctx.partitionVal.asScala.map { pVal => + val name = pVal.identifier.getText + val value = Option(pVal.constant).map(v => visitStringConstant(v, legacyNullAsString)) + name -> value + } + // Before calling `toMap`, we check duplicated keys to avoid silently ignore partition values + // in partition spec like PARTITION(a='1', b='2', a='3'). The real semantical check for + // partition columns will be done in analyzer. + if (conf.caseSensitiveAnalysis) { + checkDuplicateKeys(parts.toSeq, ctx) + } else { + checkDuplicateKeys(parts.map(kv => kv._1.toLowerCase(Locale.ROOT) -> kv._2).toSeq, ctx) + } + parts.toMap + } + + /** + * Create a partition specification map without optional values. + */ + protected def visitNonOptionalPartitionSpec( + ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { + visitPartitionSpec(ctx).map { + case (key, None) => throw emptyPartitionKeyError(key, ctx) + case (key, Some(value)) => key -> value + } + } + + /** + * Convert a constant of any type into a string. 
This is typically used in DDL commands, and its + * main purpose is to prevent slight differences due to back to back conversions i.e.: + * String -> Literal -> String. + */ + protected def visitStringConstant( + ctx: ConstantContext, + legacyNullAsString: Boolean): String = withOrigin(ctx) { + expression(ctx) match { + case Literal(null, _) if !legacyNullAsString => null + case l @ Literal(null, _) => l.toString + case l: Literal => + // TODO For v2 commands, we will cast the string back to its actual value, + // which is a waste and can be improved in the future. + Cast(l, StringType, Some(conf.sessionLocalTimeZone)).eval().toString + case other => + throw new IllegalArgumentException(s"Only literals are allowed in the " + + s"partition spec, but got ${other.sql}") + } + } + + /** + * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These + * clauses determine the shape (ordering/partitioning/rows) of the query result. + */ + private def withQueryResultClauses( + ctx: QueryOrganizationContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. + val withOrder = if ( + !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // ORDER BY ... + Sort(order.asScala.map(visitSortItem).toSeq, global = true, query) + } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... + Sort(sort.asScala.map(visitSortItem).toSeq, global = false, query) + } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // DISTRIBUTE BY ... + withRepartitionByExpression(ctx, expressionList(distributeBy), query) + } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... DISTRIBUTE BY ... + Sort( + sort.asScala.map(visitSortItem).toSeq, + global = false, + withRepartitionByExpression(ctx, expressionList(distributeBy), query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { + // CLUSTER BY ... + val expressions = expressionList(clusterBy) + Sort( + expressions.map(SortOrder(_, Ascending)), + global = false, + withRepartitionByExpression(ctx, expressions, query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // [EMPTY] + query + } else { + throw combinationQueryResultClausesUnsupportedError(ctx) + } + + // WINDOWS + val withWindow = withOrder.optionalMap(windowClause)(withWindowClause) + + // LIMIT + // - LIMIT ALL is the same as omitting the LIMIT clause + withWindow.optional(limit) { + Limit(typedVisit(limit), withWindow) + } + } + + /** + * Create a clause for DISTRIBUTE BY. 
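+   * This default implementation rejects DISTRIBUTE BY outright; a Spark-version-specific
+   * builder can override it to produce the appropriate repartition operator.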
+ */ + protected def withRepartitionByExpression( + ctx: QueryOrganizationContext, + expressions: Seq[Expression], + query: LogicalPlan): LogicalPlan = { + throw distributeByUnsupportedError(ctx) + } + + override def visitTransformQuerySpecification( + ctx: TransformQuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation().optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withTransformQuerySpecification( + ctx, + ctx.transformClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + from + ) + } + + override def visitRegularQuerySpecification( + ctx: RegularQuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation().optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withSelectQuerySpecification( + ctx, + ctx.selectClause, + ctx.lateralView, + ctx.whereClause, + ctx.aggregationClause, + ctx.havingClause, + ctx.windowClause, + from + ) + } + + override def visitNamedExpressionSeq( + ctx: NamedExpressionSeqContext): Seq[Expression] = { + Option(ctx).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + } + + override def visitExpressionSeq(ctx: ExpressionSeqContext): Seq[Expression] = { + Option(ctx).toSeq + .flatMap(_.expression.asScala) + .map(typedVisit[Expression]) + } + + /** + * Create a logical plan using a having clause. + */ + private def withHavingClause( + ctx: HavingClauseContext, plan: LogicalPlan): LogicalPlan = { + // Note that we add a cast to non-predicate expressions. If the expression itself is + // already boolean, the optimizer will get rid of the unnecessary cast. + val predicate = expression(ctx.booleanExpression) match { + case p: Predicate => p + case e => Cast(e, BooleanType) + } + UnresolvedHaving(predicate, plan) + } + + /** + * Create a logical plan using a where clause. + */ + private def withWhereClause(ctx: WhereClauseContext, plan: LogicalPlan): LogicalPlan = { + Filter(expression(ctx.booleanExpression), plan) + } + + /** + * Add a hive-style transform (SELECT TRANSFORM/MAP/REDUCE) query specification to a logical plan. + */ + private def withTransformQuerySpecification( + ctx: ParserRuleContext, + transformClause: TransformClauseContext, + lateralView: java.util.List[LateralViewContext], + whereClause: WhereClauseContext, + aggregationClause: AggregationClauseContext, + havingClause: HavingClauseContext, + windowClause: WindowClauseContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + if (transformClause.setQuantifier != null) { + throw transformNotSupportQuantifierError(transformClause.setQuantifier) + } + // Create the attributes. + val (attributes, schemaLess) = if (transformClause.colTypeList != null) { + // Typed return columns. + (createSchema(transformClause.colTypeList).toAttributes, false) + } else if (transformClause.identifierSeq != null) { + // Untyped return columns. 
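+        // e.g. `SELECT TRANSFORM(a, b) USING 'cat' AS (x, y)`: the declared names default to string type.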
+ val attrs = visitIdentifierSeq(transformClause.identifierSeq).map { name => + AttributeReference(name, StringType, nullable = true)() + } + (attrs, false) + } else { + (Seq(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) + } + + val plan = visitCommonSelectQueryClausePlan( + relation, + visitExpressionSeq(transformClause.expressionSeq), + lateralView, + whereClause, + aggregationClause, + havingClause, + windowClause, + isDistinct = false) + + ScriptTransformation( + string(transformClause.script), + attributes, + plan, + withScriptIOSchema( + ctx, + transformClause.inRowFormat, + transformClause.recordWriter, + transformClause.outRowFormat, + transformClause.recordReader, + schemaLess + ) + ) + } + + /** + * Add a regular (SELECT) query specification to a logical plan. The query specification + * is the core of the logical plan, this is where sourcing (FROM clause), projection (SELECT), + * aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place. + * + * Note that query hints are ignored (both by the parser and the builder). + */ + private def withSelectQuerySpecification( + ctx: ParserRuleContext, + selectClause: SelectClauseContext, + lateralView: java.util.List[LateralViewContext], + whereClause: WhereClauseContext, + aggregationClause: AggregationClauseContext, + havingClause: HavingClauseContext, + windowClause: WindowClauseContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val isDistinct = selectClause.setQuantifier() != null && + selectClause.setQuantifier().DISTINCT() != null + + val plan = visitCommonSelectQueryClausePlan( + relation, + visitNamedExpressionSeq(selectClause.namedExpressionSeq), + lateralView, + whereClause, + aggregationClause, + havingClause, + windowClause, + isDistinct) + + // Hint + selectClause.hints.asScala.foldRight(plan)(withHints) + } + + def visitCommonSelectQueryClausePlan( + relation: LogicalPlan, + expressions: Seq[Expression], + lateralView: java.util.List[LateralViewContext], + whereClause: WhereClauseContext, + aggregationClause: AggregationClauseContext, + havingClause: HavingClauseContext, + windowClause: WindowClauseContext, + isDistinct: Boolean): LogicalPlan = { + // Add lateral views. + val withLateralView = lateralView.asScala.foldLeft(relation)(withGenerate) + + // Add where. + val withFilter = withLateralView.optionalMap(whereClause)(withWhereClause) + + // Add aggregation or a project. + val namedExpressions = expressions.map { + case e: NamedExpression => e + case e: Expression => UnresolvedAlias(e) + } + + def createProject() = if (namedExpressions.nonEmpty) { + Project(namedExpressions, withFilter) + } else { + withFilter + } + + val withProject = if (aggregationClause == null && havingClause != null) { + if (conf.getConf(SQLConf.LEGACY_HAVING_WITHOUT_GROUP_BY_AS_WHERE)) { + // If the legacy conf is set, treat HAVING without GROUP BY as WHERE. + val predicate = expression(havingClause.booleanExpression) match { + case p: Predicate => p + case e => Cast(e, BooleanType) + } + Filter(predicate, createProject()) + } else { + // According to SQL standard, HAVING without GROUP BY means global aggregate. + withHavingClause(havingClause, Aggregate(Nil, namedExpressions, withFilter)) + } + } else if (aggregationClause != null) { + val aggregate = withAggregationClause(aggregationClause, namedExpressions, withFilter) + aggregate.optionalMap(havingClause)(withHavingClause) + } else { + // When hitting this branch, `having` must be null. 
+ createProject() + } + + // Distinct + val withDistinct = if (isDistinct) { + Distinct(withProject) + } else { + withProject + } + + // Window + val withWindow = withDistinct.optionalMap(windowClause)(withWindowClause) + + withWindow + } + + // Script Transform's input/output format. + type ScriptIOFormat = + (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) + + protected def getRowFormatDelimited(ctx: RowFormatDelimitedContext): ScriptIOFormat = { + // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema + // expects a seq of pairs in which the old parsers' token names are used as keys. + // Transforming the result of visitRowFormatDelimited would be quite a bit messier than + // retrieving the key value pairs ourselves. + val entries = entry("TOK_TABLEROWFORMATFIELD", ctx.fieldsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATCOLLITEMS", ctx.collectionItemsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATMAPKEYS", ctx.keysTerminatedBy) ++ + entry("TOK_TABLEROWFORMATNULL", ctx.nullDefinedAs) ++ + Option(ctx.linesSeparatedBy).toSeq.map { token => + val value = string(token) + validate( + value == "\n", + s"LINES TERMINATED BY only supports newline '\\n' right now: $value", + ctx) + "TOK_TABLEROWFORMATLINES" -> value + } + + (entries, None, Seq.empty, None) + } + + /** + * Create a [[ScriptInputOutputSchema]]. + */ + protected def withScriptIOSchema( + ctx: ParserRuleContext, + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): ScriptInputOutputSchema = { + + def format(fmt: RowFormatContext): ScriptIOFormat = fmt match { + case c: RowFormatDelimitedContext => + getRowFormatDelimited(c) + + case c: RowFormatSerdeContext => + throw transformWithSerdeUnsupportedError(ctx) + + // SPARK-32106: When there is no definition about format, we return empty result + // to use a built-in default Serde in SparkScriptTransformationExec. + case null => + (Nil, None, Seq.empty, None) + } + + val (inFormat, inSerdeClass, inSerdeProps, reader) = format(inRowFormat) + + val (outFormat, outSerdeClass, outSerdeProps, writer) = format(outRowFormat) + + ScriptInputOutputSchema( + inFormat, outFormat, + inSerdeClass, outSerdeClass, + inSerdeProps, outSerdeProps, + reader, writer, + schemaLess) + } + + /** + * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma + * separated) relations here, these get converted into a single plan by condition-less inner join. + */ + override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { + val from = ctx.relation.asScala.foldLeft(null: LogicalPlan) { (left, relation) => + val right = plan(relation.relationPrimary) + val join = right.optionalMap(left) { (left, right) => + if (relation.LATERAL != null) { + if (!relation.relationPrimary.isInstanceOf[AliasedQueryContext]) { + throw invalidLateralJoinRelationError(relation.relationPrimary) + } + LateralJoin(left, LateralSubquery(right), Inner, None) + } else { + Join(left, right, Inner, None, JoinHint.NONE) + } + } + withJoinRelations(join, relation) + } + if (ctx.pivotClause() != null) { + if (!ctx.lateralView.isEmpty) { + throw lateralWithPivotInFromClauseNotAllowedError(ctx) + } + withPivot(ctx.pivotClause, from) + } else { + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } + } + + /** + * Connect two queries by a Set operator. 
+ * + * Supported Set operators are: + * - UNION [ DISTINCT | ALL ] + * - EXCEPT [ DISTINCT | ALL ] + * - MINUS [ DISTINCT | ALL ] + * - INTERSECT [DISTINCT | ALL] + */ + override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { + val left = plan(ctx.left) + val right = plan(ctx.right) + val all = Option(ctx.setQuantifier()).exists(_.ALL != null) + ctx.operator.getType match { + case HoodieSqlBaseParser.UNION if all => + Union(left, right) + case HoodieSqlBaseParser.UNION => + Distinct(Union(left, right)) + case HoodieSqlBaseParser.INTERSECT if all => + Intersect(left, right, isAll = true) + case HoodieSqlBaseParser.INTERSECT => + Intersect(left, right, isAll = false) + case HoodieSqlBaseParser.EXCEPT if all => + Except(left, right, isAll = true) + case HoodieSqlBaseParser.EXCEPT => + Except(left, right, isAll = false) + case HoodieSqlBaseParser.SETMINUS if all => + Except(left, right, isAll = true) + case HoodieSqlBaseParser.SETMINUS => + Except(left, right, isAll = false) + } + } + + /** + * Add a [[WithWindowDefinition]] operator to a logical plan. + */ + private def withWindowClause( + ctx: WindowClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Collect all window specifications defined in the WINDOW clause. + val baseWindowTuples = ctx.namedWindow.asScala.map { + wCtx => + (wCtx.name.getText, typedVisit[WindowSpec](wCtx.windowSpec)) + } + baseWindowTuples.groupBy(_._1).foreach { kv => + if (kv._2.size > 1) { + throw repetitiveWindowDefinitionError(kv._1, ctx) + } + } + val baseWindowMap = baseWindowTuples.toMap + + // Handle cases like + // window w1 as (partition by p_mfgr order by p_name + // range between 2 preceding and 2 following), + // w2 as w1 + val windowMapView = baseWindowMap.mapValues { + case WindowSpecReference(name) => + baseWindowMap.get(name) match { + case Some(spec: WindowSpecDefinition) => + spec + case Some(ref) => + throw invalidWindowReferenceError(name, ctx) + case None => + throw cannotResolveWindowReferenceError(name, ctx) + } + case spec: WindowSpecDefinition => spec + } + + // Note that mapValues creates a view instead of materialized map. We force materialization by + // mapping over identity. + WithWindowDefinition(windowMapView.map(identity).toMap, query) + } + + /** + * Add an [[Aggregate]] to a logical plan. + */ + private def withAggregationClause( + ctx: AggregationClauseContext, + selectExpressions: Seq[NamedExpression], + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + if (ctx.groupingExpressionsWithGroupingAnalytics.isEmpty) { + val groupByExpressions = expressionList(ctx.groupingExpressions) + if (ctx.GROUPING != null) { + // GROUP BY ... GROUPING SETS (...) + // `groupByExpressions` can be non-empty for Hive compatibility. It may add extra grouping + // expressions that do not exist in GROUPING SETS (...), and the value is always null. + // For example, `SELECT a, b, c FROM ... GROUP BY a, b, c GROUPING SETS (a, b)`, the output + // of column `c` is always null. + val groupingSets = + ctx.groupingSet.asScala.map(_.expression.asScala.map(e => expression(e)).toSeq) + Aggregate(Seq(GroupingSets(groupingSets.toSeq, groupByExpressions)), + selectExpressions, query) + } else { + // GROUP BY .... (WITH CUBE | WITH ROLLUP)? 
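+        // e.g. `GROUP BY a, b WITH CUBE` wraps the grouping expressions in a single
+        // Cube(Seq(Seq(a), Seq(b))) entry, while a plain `GROUP BY a, b` keeps them as-is.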
+ val mappedGroupByExpressions = if (ctx.CUBE != null) { + Seq(Cube(groupByExpressions.map(Seq(_)))) + } else if (ctx.ROLLUP != null) { + Seq(Rollup(groupByExpressions.map(Seq(_)))) + } else { + groupByExpressions + } + Aggregate(mappedGroupByExpressions, selectExpressions, query) + } + } else { + val groupByExpressions = + ctx.groupingExpressionsWithGroupingAnalytics.asScala + .map(groupByExpr => { + val groupingAnalytics = groupByExpr.groupingAnalytics + if (groupingAnalytics != null) { + visitGroupingAnalytics(groupingAnalytics) + } else { + expression(groupByExpr.expression) + } + }) + Aggregate(groupByExpressions.toSeq, selectExpressions, query) + } + } + + override def visitGroupingAnalytics( + groupingAnalytics: GroupingAnalyticsContext): BaseGroupingSets = { + val groupingSets = groupingAnalytics.groupingSet.asScala + .map(_.expression.asScala.map(e => expression(e)).toSeq) + if (groupingAnalytics.CUBE != null) { + // CUBE(A, B, (A, B), ()) is not supported. + if (groupingSets.exists(_.isEmpty)) { + throw invalidGroupingSetError("CUBE", groupingAnalytics) + } + Cube(groupingSets.toSeq) + } else if (groupingAnalytics.ROLLUP != null) { + // ROLLUP(A, B, (A, B), ()) is not supported. + if (groupingSets.exists(_.isEmpty)) { + throw invalidGroupingSetError("ROLLUP", groupingAnalytics) + } + Rollup(groupingSets.toSeq) + } else { + assert(groupingAnalytics.GROUPING != null && groupingAnalytics.SETS != null) + val groupingSets = groupingAnalytics.groupingElement.asScala.flatMap { expr => + val groupingAnalytics = expr.groupingAnalytics() + if (groupingAnalytics != null) { + visitGroupingAnalytics(groupingAnalytics).selectedGroupByExprs + } else { + Seq(expr.groupingSet().expression().asScala.map(e => expression(e)).toSeq) + } + } + GroupingSets(groupingSets.toSeq) + } + } + + /** + * Add [[UnresolvedHint]]s to a logical plan. + */ + private def withHints( + ctx: HintContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + var plan = query + ctx.hintStatements.asScala.reverse.foreach { stmt => + plan = UnresolvedHint(stmt.hintName.getText, + stmt.parameters.asScala.map(expression).toSeq, plan) + } + plan + } + + /** + * Add a [[Pivot]] to a logical plan. + */ + private def withPivot( + ctx: PivotClauseContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val aggregates = Option(ctx.aggregates).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) { + UnresolvedAttribute.quoted(ctx.pivotColumn.identifier.getText) + } else { + CreateStruct( + ctx.pivotColumn.identifiers.asScala.map( + identifier => UnresolvedAttribute.quoted(identifier.getText)).toSeq) + } + val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue) + Pivot(None, pivotColumn, pivotValues.toSeq, aggregates, query) + } + + /** + * Create a Pivot column value with or without an alias. + */ + override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else { + e + } + } + + /** + * Add a [[Generate]] (Lateral View) to a logical plan. 
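+   * For example, `LATERAL VIEW explode(items) t AS item` produces a Generate node over an
+   * [[UnresolvedGenerator]], qualified by the (lower-cased) view name `t` and emitting the
+   * quoted output attribute `item`.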
+ */ + private def withGenerate( + query: LogicalPlan, + ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) { + val expressions = expressionList(ctx.expression) + Generate( + UnresolvedGenerator(visitFunctionName(ctx.qualifiedName), expressions), + unrequiredChildIndex = Nil, + outer = ctx.OUTER != null, + // scalastyle:off caselocale + Some(ctx.tblName.getText.toLowerCase), + // scalastyle:on caselocale + ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.quoted).toSeq, + query) + } + + /** + * Create a single relation referenced in a FROM clause. This method is used when a part of the + * join condition is nested, for example: + * {{{ + * select * from t1 join (t2 cross join t3) on col1 = col2 + * }}} + */ + override def visitRelation(ctx: RelationContext): LogicalPlan = withOrigin(ctx) { + withJoinRelations(plan(ctx.relationPrimary), ctx) + } + + /** + * Join one more [[LogicalPlan]]s to the current logical plan. + */ + private def withJoinRelations(base: LogicalPlan, ctx: RelationContext): LogicalPlan = { + ctx.joinRelation.asScala.foldLeft(base) { (left, join) => + withOrigin(join) { + val baseJoinType = join.joinType match { + case null => Inner + case jt if jt.CROSS != null => Cross + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.ANTI != null => LeftAnti + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } + + if (join.LATERAL != null && !join.right.isInstanceOf[AliasedQueryContext]) { + throw invalidLateralJoinRelationError(join.right) + } + + // Resolve the join type and join condition + val (joinType, condition) = Option(join.joinCriteria) match { + case Some(c) if c.USING != null => + if (join.LATERAL != null) { + throw lateralJoinWithUsingJoinUnsupportedError(ctx) + } + (UsingJoin(baseJoinType, visitIdentifierList(c.identifierList)), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case Some(c) => + throw joinCriteriaUnimplementedError(c, ctx) + case None if join.NATURAL != null => + if (join.LATERAL != null) { + throw lateralJoinWithNaturalJoinUnsupportedError(ctx) + } + if (baseJoinType == Cross) { + throw naturalCrossJoinUnsupportedError(ctx) + } + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + if (join.LATERAL != null) { + if (!Seq(Inner, Cross, LeftOuter).contains(joinType)) { + throw unsupportedLateralJoinTypeError(ctx, joinType.toString) + } + LateralJoin(left, LateralSubquery(plan(join.right)), joinType, condition) + } else { + Join(left, plan(join.right), joinType, condition, JoinHint.NONE) + } + } + } + } + + /** + * Add a [[Sample]] to a logical plan. + * + * This currently supports the following sampling methods: + * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. + * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages + * are defined as a number between 0 and 100. + * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction. + */ + private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Create a sampled plan if we need one. + def sample(fraction: Double): Sample = { + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. 
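+      // e.g. TABLESAMPLE(50 PERCENT) reaches this helper with fraction = 0.5.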
+ val eps = RandomSampler.roundingEpsilon + validate(fraction >= 0.0 - eps && fraction <= 1.0 + eps, + s"Sampling fraction ($fraction) must be on interval [0, 1]", + ctx) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query) + } + + if (ctx.sampleMethod() == null) { + throw emptyInputForTableSampleError(ctx) + } + + ctx.sampleMethod() match { + case ctx: SampleByRowsContext => + Limit(expression(ctx.expression), query) + + case ctx: SampleByPercentileContext => + val fraction = ctx.percentage.getText.toDouble + val sign = if (ctx.negativeSign == null) 1 else -1 + sample(sign * fraction / 100.0d) + + case ctx: SampleByBytesContext => + val bytesStr = ctx.bytes.getText + if (bytesStr.matches("[0-9]+[bBkKmMgG]")) { + throw tableSampleByBytesUnsupportedError("byteLengthLiteral", ctx) + } else { + throw invalidByteLengthLiteralError(bytesStr, ctx) + } + + case ctx: SampleByBucketContext if ctx.ON() != null => + if (ctx.identifier != null) { + throw tableSampleByBytesUnsupportedError( + "BUCKET x OUT OF y ON colname", ctx) + } else { + throw tableSampleByBytesUnsupportedError( + "BUCKET x OUT OF y ON function", ctx) + } + + case ctx: SampleByBucketContext => + sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) + } + } + + /** + * Create a logical plan for a sub-query. + */ + override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.query) + } + + /** + * Create an un-aliased table reference. This is typically used for top-level table references, + * for example: + * {{{ + * INSERT INTO db.tbl2 + * TABLE db.tbl1 + * }}} + */ + override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { + UnresolvedRelation(visitMultipartIdentifier(ctx.multipartIdentifier)) + } + + /** + * Create a table-valued function call with arguments, e.g. range(1000) + */ + override def visitTableValuedFunction(ctx: TableValuedFunctionContext) + : LogicalPlan = withOrigin(ctx) { + val func = ctx.functionTable + val aliases = if (func.tableAlias.identifierList != null) { + visitIdentifierList(func.tableAlias.identifierList) + } else { + Seq.empty + } + val name = getFunctionIdentifier(func.functionName) + if (name.database.nonEmpty) { + operationNotAllowed(s"table valued function cannot specify database name: $name", ctx) + } + + val tvf = UnresolvedTableValuedFunction( + name, func.expression.asScala.map(expression).toSeq, aliases) + tvf.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan) + } + + /** + * Create an inline table (a virtual table in Hive parlance). + */ + override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { + // Get the backing expressions. + val rows = ctx.expression.asScala.map { e => + expression(e) match { + // inline table comes in two styles: + // style 1: values (1), (2), (3) -- multiple columns are supported + // style 2: values 1, 2, 3 -- only a single column is supported here + case struct: CreateNamedStruct => struct.valExprs // style 1 + case child => Seq(child) // style 2 + } + } + + val aliases = if (ctx.tableAlias.identifierList != null) { + visitIdentifierList(ctx.tableAlias.identifierList) + } else { + Seq.tabulate(rows.head.size)(i => s"col${i + 1}") + } + + val table = UnresolvedInlineTable(aliases, rows.toSeq) + table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a join relation. 
This is practically the same as + * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. We could add alias names for output columns, for example: + * {{{ + * SELECT a, b, c, d FROM (src1 s1 INNER JOIN src2 s2 ON s1.id = s2.id) dst(a, b, c, d) + * }}} + */ + override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) { + val relation = plan(ctx.relation).optionalMap(ctx.sample)(withSample) + mayApplyAliasPlan(ctx.tableAlias, relation) + } + + /** + * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as + * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. We could add alias names for output columns, for example: + * {{{ + * SELECT col1, col2 FROM testData AS t(col1, col2) + * }}} + */ + override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { + val relation = plan(ctx.query).optionalMap(ctx.sample)(withSample) + if (ctx.tableAlias.strictIdentifier == null) { + // For un-aliased subqueries, use a default alias name that is not likely to conflict with + // normal subquery names, so that parent operators can only access the columns in subquery by + // unqualified names. Users can still use this special qualifier to access columns if they + // know it, but that's not recommended. + SubqueryAlias("__auto_generated_subquery_name", relation) + } else { + mayApplyAliasPlan(ctx.tableAlias, relation) + } + } + + /** + * Create an alias ([[SubqueryAlias]]) for a [[LogicalPlan]]. + */ + private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = { + SubqueryAlias(alias.getText, plan) + } + + /** + * If aliases specified in a FROM clause, create a subquery alias ([[SubqueryAlias]]) and + * column aliases for a [[LogicalPlan]]. + */ + private def mayApplyAliasPlan(tableAlias: TableAliasContext, plan: LogicalPlan): LogicalPlan = { + if (tableAlias.strictIdentifier != null) { + val alias = tableAlias.strictIdentifier.getText + if (tableAlias.identifierList != null) { + val columnNames = visitIdentifierList(tableAlias.identifierList) + SubqueryAlias(alias, UnresolvedSubqueryColumnAliases(columnNames, plan)) + } else { + SubqueryAlias(alias, plan) + } + } else { + plan + } + } + + /** + * Create a Sequence of Strings for a parenthesis enclosed alias list. + */ + override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) { + visitIdentifierSeq(ctx.identifierSeq) + } + + /** + * Create a Sequence of Strings for an identifier list. + */ + override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { + ctx.ident.asScala.map(_.getText).toSeq + } + + /* ******************************************************************************************** + * Table Identifier parsing + * ******************************************************************************************** */ + /** + * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern. + */ + override def visitTableIdentifier( + ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) { + TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) + } + + /** + * Create a [[FunctionIdentifier]] from a 'functionName' or 'databaseName'.'functionName' pattern. 
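+   * For example, `my_db.my_udf` becomes FunctionIdentifier("my_udf", Some("my_db")).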
+ */ + override def visitFunctionIdentifier( + ctx: FunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) { + FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText)) + } + + /** + * Create a multi-part identifier. + */ + override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] = + withOrigin(ctx) { + ctx.parts.asScala.map(_.getText).toSeq + } + + /* ******************************************************************************************** + * Expression parsing + * ******************************************************************************************** */ + /** + * Create an expression from the given context. This method just passes the context on to the + * visitor and only takes care of typing (We assume that the visitor returns an Expression here). + */ + protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx) + + /** + * Create sequence of expressions from the given sequence of contexts. + */ + private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = { + trees.asScala.map(expression).toSeq + } + + /** + * Create a star (i.e. all) expression; this selects all elements (in the specified object). + * Both un-targeted (global) and targeted aliases are supported. + */ + override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) { + UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText).toSeq)) + } + + /** + * Create an aliased expression if an alias is specified. Both single and multi-aliases are + * supported. + */ + override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.name != null) { + Alias(e, ctx.name.getText)() + } else if (ctx.identifierList != null) { + MultiAlias(e, visitIdentifierList(ctx.identifierList)) + } else { + e + } + } + + /** + * Combine a number of boolean expressions into a balanced expression tree. These expressions are + * either combined by a logical [[And]] or a logical [[Or]]. + * + * A balanced binary tree is created because regular left recursive trees cause considerable + * performance degradations and can cause stack overflows. + */ + override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) { + val expressionType = ctx.operator.getType + val expressionCombiner = expressionType match { + case HoodieSqlBaseParser.AND => And.apply _ + case HoodieSqlBaseParser.OR => Or.apply _ + } + + // Collect all similar left hand contexts. + val contexts = ArrayBuffer(ctx.right) + var current = ctx.left + def collectContexts: Boolean = current match { + case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType => + contexts += lbc.right + current = lbc.left + true + case _ => + contexts += current + false + } + while (collectContexts) { + // No body - all updates take place in the collectContexts. + } + + // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them + // into expressions. + val expressions = contexts.reverseMap(expression) + + // Create a balanced tree. 
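+    // e.g. `a AND b AND c AND d` is reduced to And(And(a, b), And(c, d)) rather than a
+    // left-deep chain, keeping the tree depth logarithmic in the number of operands.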
+ def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match { + case 0 => + expressions(low) + case 1 => + expressionCombiner(expressions(low), expressions(high)) + case x => + val mid = low + x / 2 + expressionCombiner( + reduceToExpressionTree(low, mid), + reduceToExpressionTree(mid + 1, high)) + } + reduceToExpressionTree(0, expressions.size - 1) + } + + /** + * Invert a boolean expression. + */ + override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) { + Not(expression(ctx.booleanExpression())) + } + + /** + * Create a filtering correlated sub-query (EXISTS). + */ + override def visitExists(ctx: ExistsContext): Expression = { + Exists(plan(ctx.query)) + } + + /** + * Create a comparison expression. This compares two expressions. The following comparison + * operators are supported: + * - Equal: '=' or '==' + * - Null-safe Equal: '<=>' + * - Not Equal: '<>' or '!=' + * - Less than: '<' + * - Less then or Equal: '<=' + * - Greater than: '>' + * - Greater then or Equal: '>=' + */ + override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode] + operator.getSymbol.getType match { + case HoodieSqlBaseParser.EQ => + EqualTo(left, right) + case HoodieSqlBaseParser.NSEQ => + EqualNullSafe(left, right) + case HoodieSqlBaseParser.NEQ | HoodieSqlBaseParser.NEQJ => + Not(EqualTo(left, right)) + case HoodieSqlBaseParser.LT => + LessThan(left, right) + case HoodieSqlBaseParser.LTE => + LessThanOrEqual(left, right) + case HoodieSqlBaseParser.GT => + GreaterThan(left, right) + case HoodieSqlBaseParser.GTE => + GreaterThanOrEqual(left, right) + } + } + + /** + * Create a predicated expression. A predicated expression is a normal expression with a + * predicate attached to it, for example: + * {{{ + * a + 1 IS NULL + * }}} + */ + override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) { + val e = expression(ctx.valueExpression) + if (ctx.predicate != null) { + withPredicate(e, ctx.predicate) + } else { + e + } + } + + /** + * Add a predicate to the given expression. Supported expressions are: + * - (NOT) BETWEEN + * - (NOT) IN + * - (NOT) LIKE (ANY | SOME | ALL) + * - (NOT) RLIKE + * - IS (NOT) NULL. + * - IS (NOT) (TRUE | FALSE | UNKNOWN) + * - IS (NOT) DISTINCT FROM + */ + private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) { + // Invert a predicate if it has a valid NOT clause. + def invertIfNotDefined(e: Expression): Expression = ctx.NOT match { + case null => e + case not => Not(e) + } + + def getValueExpressions(e: Expression): Seq[Expression] = e match { + case c: CreateNamedStruct => c.valExprs + case other => Seq(other) + } + + // Create the predicate. 
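+    // e.g. `a BETWEEN 1 AND 5` becomes And(GreaterThanOrEqual(a, 1), LessThanOrEqual(a, 5)),
+    // and `a NOT BETWEEN 1 AND 5` wraps that result in Not(...).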
+ ctx.kind.getType match { + case HoodieSqlBaseParser.BETWEEN => + // BETWEEN is translated to lower <= e && e <= upper + invertIfNotDefined(And( + GreaterThanOrEqual(e, expression(ctx.lower)), + LessThanOrEqual(e, expression(ctx.upper)))) + case HoodieSqlBaseParser.IN if ctx.query != null => + invertIfNotDefined(InSubquery(getValueExpressions(e), ListQuery(plan(ctx.query)))) + case HoodieSqlBaseParser.IN => + invertIfNotDefined(In(e, ctx.expression.asScala.map(expression).toSeq)) + case HoodieSqlBaseParser.LIKE => + Option(ctx.quantifier).map(_.getType) match { + case Some(HoodieSqlBaseParser.ANY) | Some(HoodieSqlBaseParser.SOME) => + validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx) + val expressions = expressionList(ctx.expression) + if (expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) { + // If there are many pattern expressions, will throw StackOverflowError. + // So we use LikeAny or NotLikeAny instead. + val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String]) + ctx.NOT match { + case null => LikeAny(e, patterns) + case _ => NotLikeAny(e, patterns) + } + } else { + ctx.expression.asScala.map(expression) + .map(p => invertIfNotDefined(new Like(e, p))).toSeq.reduceLeft(Or) + } + case Some(HoodieSqlBaseParser.ALL) => + validate(!ctx.expression.isEmpty, "Expected something between '(' and ')'.", ctx) + val expressions = expressionList(ctx.expression) + if (expressions.forall(_.foldable) && expressions.forall(_.dataType == StringType)) { + // If there are many pattern expressions, will throw StackOverflowError. + // So we use LikeAll or NotLikeAll instead. + val patterns = expressions.map(_.eval(EmptyRow).asInstanceOf[UTF8String]) + ctx.NOT match { + case null => LikeAll(e, patterns) + case _ => NotLikeAll(e, patterns) + } + } else { + ctx.expression.asScala.map(expression) + .map(p => invertIfNotDefined(new Like(e, p))).toSeq.reduceLeft(And) + } + case _ => + val escapeChar = Option(ctx.escapeChar).map(string).map { str => + if (str.length != 1) { + throw invalidEscapeStringError(ctx) + } + str.charAt(0) + }.getOrElse('\\') + invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar)) + } + case HoodieSqlBaseParser.RLIKE => + invertIfNotDefined(RLike(e, expression(ctx.pattern))) + case HoodieSqlBaseParser.NULL if ctx.NOT != null => + IsNotNull(e) + case HoodieSqlBaseParser.NULL => + IsNull(e) + case HoodieSqlBaseParser.TRUE => ctx.NOT match { + case null => EqualNullSafe(e, Literal(true)) + case _ => Not(EqualNullSafe(e, Literal(true))) + } + case HoodieSqlBaseParser.FALSE => ctx.NOT match { + case null => EqualNullSafe(e, Literal(false)) + case _ => Not(EqualNullSafe(e, Literal(false))) + } + case HoodieSqlBaseParser.UNKNOWN => ctx.NOT match { + case null => IsUnknown(e) + case _ => IsNotUnknown(e) + } + case HoodieSqlBaseParser.DISTINCT if ctx.NOT != null => + EqualNullSafe(e, expression(ctx.right)) + case HoodieSqlBaseParser.DISTINCT => + Not(EqualNullSafe(e, expression(ctx.right))) + } + } + + /** + * Create a binary arithmetic expression. 
The following arithmetic operators are supported: + * - Multiplication: '*' + * - Division: '/' + * - Hive Long Division: 'DIV' + * - Modulo: '%' + * - Addition: '+' + * - Subtraction: '-' + * - Binary AND: '&' + * - Binary XOR + * - Binary OR: '|' + */ + override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + ctx.operator.getType match { + case HoodieSqlBaseParser.ASTERISK => + Multiply(left, right) + case HoodieSqlBaseParser.SLASH => + Divide(left, right) + case HoodieSqlBaseParser.PERCENT => + Remainder(left, right) + case HoodieSqlBaseParser.DIV => + IntegralDivide(left, right) + case HoodieSqlBaseParser.PLUS => + Add(left, right) + case HoodieSqlBaseParser.MINUS => + Subtract(left, right) + case HoodieSqlBaseParser.CONCAT_PIPE => + Concat(left :: right :: Nil) + case HoodieSqlBaseParser.AMPERSAND => + BitwiseAnd(left, right) + case HoodieSqlBaseParser.HAT => + BitwiseXor(left, right) + case HoodieSqlBaseParser.PIPE => + BitwiseOr(left, right) + } + } + + /** + * Create a unary arithmetic expression. The following arithmetic operators are supported: + * - Plus: '+' + * - Minus: '-' + * - Bitwise Not: '~' + */ + override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) { + val value = expression(ctx.valueExpression) + ctx.operator.getType match { + case HoodieSqlBaseParser.PLUS => + UnaryPositive(value) + case HoodieSqlBaseParser.MINUS => + UnaryMinus(value) + case HoodieSqlBaseParser.TILDE => + BitwiseNot(value) + } + } + + override def visitCurrentLike(ctx: CurrentLikeContext): Expression = withOrigin(ctx) { + if (conf.ansiEnabled) { + ctx.name.getType match { + case HoodieSqlBaseParser.CURRENT_DATE => + CurrentDate() + case HoodieSqlBaseParser.CURRENT_TIMESTAMP => + CurrentTimestamp() + case HoodieSqlBaseParser.CURRENT_USER => + CurrentUser() + } + } else { + // If the parser is not in ansi mode, we should return `UnresolvedAttribute`, in case there + // are columns named `CURRENT_DATE` or `CURRENT_TIMESTAMP`. + UnresolvedAttribute.quoted(ctx.name.getText) + } + } + + /** + * Create a [[Cast]] expression. + */ + override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) { + val rawDataType = typedVisit[DataType](ctx.dataType()) + val dataType = CharVarcharUtils.replaceCharVarcharWithStringForCast(rawDataType) + val cast = ctx.name.getType match { + case HoodieSqlBaseParser.CAST => + Cast(expression(ctx.expression), dataType) + + case HoodieSqlBaseParser.TRY_CAST => + TryCast(expression(ctx.expression), dataType) + } + cast.setTagValue(Cast.USER_SPECIFIED_CAST, true) + cast + } + + /** + * Create a [[CreateStruct]] expression. + */ + override def visitStruct(ctx: StructContext): Expression = withOrigin(ctx) { + CreateStruct.create(ctx.argument.asScala.map(expression).toSeq) + } + + /** + * Create a [[First]] expression. + */ + override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + First(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression() + } + + /** + * Create a [[Last]] expression. + */ + override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) { + val ignoreNullsExpr = ctx.IGNORE != null + Last(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression() + } + + /** + * Create a Position expression. 
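+   * For example, `position('b' in 'abc')` is translated to StringLocate('b', 'abc').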
+ */ + override def visitPosition(ctx: PositionContext): Expression = withOrigin(ctx) { + new StringLocate(expression(ctx.substr), expression(ctx.str)) + } + + /** + * Create a Extract expression. + */ + override def visitExtract(ctx: ExtractContext): Expression = withOrigin(ctx) { + val arguments = Seq(Literal(ctx.field.getText), expression(ctx.source)) + UnresolvedFunction("extract", arguments, isDistinct = false) + } + + /** + * Create a Substring/Substr expression. + */ + override def visitSubstring(ctx: SubstringContext): Expression = withOrigin(ctx) { + if (ctx.len != null) { + Substring(expression(ctx.str), expression(ctx.pos), expression(ctx.len)) + } else { + new Substring(expression(ctx.str), expression(ctx.pos)) + } + } + + /** + * Create a Trim expression. + */ + override def visitTrim(ctx: TrimContext): Expression = withOrigin(ctx) { + val srcStr = expression(ctx.srcStr) + val trimStr = Option(ctx.trimStr).map(expression) + Option(ctx.trimOption).map(_.getType).getOrElse(HoodieSqlBaseParser.BOTH) match { + case HoodieSqlBaseParser.BOTH => + StringTrim(srcStr, trimStr) + case HoodieSqlBaseParser.LEADING => + StringTrimLeft(srcStr, trimStr) + case HoodieSqlBaseParser.TRAILING => + StringTrimRight(srcStr, trimStr) + case other => + throw trimOptionUnsupportedError(other, ctx) + } + } + + /** + * Create a Overlay expression. + */ + override def visitOverlay(ctx: OverlayContext): Expression = withOrigin(ctx) { + val input = expression(ctx.input) + val replace = expression(ctx.replace) + val position = expression(ctx.position) + val lengthOpt = Option(ctx.length).map(expression) + lengthOpt match { + case Some(length) => Overlay(input, replace, position, length) + case None => new Overlay(input, replace, position) + } + } + + /** + * Create a (windowed) Function expression. + */ + override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { + // Create the function call. + val name = ctx.functionName.getText + val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) + // Call `toSeq`, otherwise `ctx.argument.asScala.map(expression)` is `Buffer` in Scala 2.13 + val arguments = ctx.argument.asScala.map(expression).toSeq match { + case Seq(UnresolvedStar(None)) + if name.toLowerCase(Locale.ROOT) == "count" && !isDistinct => + // Transform COUNT(*) into COUNT(1). + Seq(Literal(1)) + case expressions => + expressions + } + val filter = Option(ctx.where).map(expression(_)) + val ignoreNulls = + Option(ctx.nullsOption).map(_.getType == HoodieSqlBaseParser.IGNORE).getOrElse(false) + val function = UnresolvedFunction( + getFunctionMultiparts(ctx.functionName), arguments, isDistinct, filter, ignoreNulls) + + // Check if the function is evaluated in a windowed context. + ctx.windowSpec match { + case spec: WindowRefContext => + UnresolvedWindowExpression(function, visitWindowRef(spec)) + case spec: WindowDefContext => + WindowExpression(function, visitWindowDef(spec)) + case _ => function + } + } + + /** + * Create a function database (optional) and name pair. + */ + protected def visitFunctionName(ctx: QualifiedNameContext): FunctionIdentifier = { + visitFunctionName(ctx, ctx.identifier().asScala.map(_.getText).toSeq) + } + + /** + * Create a function database (optional) and name pair. 
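+   * Only one- or two-part names are accepted here; e.g. Seq("db", "fn") yields
+   * FunctionIdentifier("fn", Some("db")), and anything longer is rejected.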
+ */ + private def visitFunctionName(ctx: ParserRuleContext, texts: Seq[String]): FunctionIdentifier = { + texts match { + case Seq(db, fn) => FunctionIdentifier(fn, Option(db)) + case Seq(fn) => FunctionIdentifier(fn, None) + case other => + throw functionNameUnsupportedError(texts.mkString("."), ctx) + } + } + + /** + * Get a function identifier consist by database (optional) and name. + */ + protected def getFunctionIdentifier(ctx: FunctionNameContext): FunctionIdentifier = { + if (ctx.qualifiedName != null) { + visitFunctionName(ctx.qualifiedName) + } else { + FunctionIdentifier(ctx.getText, None) + } + } + + protected def getFunctionMultiparts(ctx: FunctionNameContext): Seq[String] = { + if (ctx.qualifiedName != null) { + ctx.qualifiedName().identifier().asScala.map(_.getText).toSeq + } else { + Seq(ctx.getText) + } + } + + /** + * Create an [[LambdaFunction]]. + */ + override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) { + val arguments = ctx.identifier().asScala.map { name => + UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts) + } + val function = expression(ctx.expression).transformUp { + case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts) + } + LambdaFunction(function, arguments.toSeq) + } + + /** + * Create a reference to a window frame, i.e. [[WindowSpecReference]]. + */ + override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) { + WindowSpecReference(ctx.name.getText) + } + + /** + * Create a window definition, i.e. [[WindowSpecDefinition]]. + */ + override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) { + // CLUSTER BY ... | PARTITION BY ... ORDER BY ... + val partition = ctx.partition.asScala.map(expression) + val order = ctx.sortItem.asScala.map(visitSortItem) + + // RANGE/ROWS BETWEEN ... + val frameSpecOption = Option(ctx.windowFrame).map { frame => + val frameType = frame.frameType.getType match { + case HoodieSqlBaseParser.RANGE => RangeFrame + case HoodieSqlBaseParser.ROWS => RowFrame + } + + SpecifiedWindowFrame( + frameType, + visitFrameBound(frame.start), + Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow)) + } + + WindowSpecDefinition( + partition.toSeq, + order.toSeq, + frameSpecOption.getOrElse(UnspecifiedFrame)) + } + + /** + * Create or resolve a frame boundary expressions. + */ + override def visitFrameBound(ctx: FrameBoundContext): Expression = withOrigin(ctx) { + def value: Expression = { + val e = expression(ctx.expression) + validate(e.resolved && e.foldable, "Frame bound value must be a literal.", ctx) + e + } + + ctx.boundType.getType match { + case HoodieSqlBaseParser.PRECEDING if ctx.UNBOUNDED != null => + UnboundedPreceding + case HoodieSqlBaseParser.PRECEDING => + UnaryMinus(value) + case HoodieSqlBaseParser.CURRENT => + CurrentRow + case HoodieSqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null => + UnboundedFollowing + case HoodieSqlBaseParser.FOLLOWING => + value + } + } + + /** + * Create a [[CreateStruct]] expression. + */ + override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { + CreateStruct(ctx.namedExpression().asScala.map(expression).toSeq) + } + + /** + * Create a [[ScalarSubquery]] expression. + */ + override def visitSubqueryExpression( + ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) { + ScalarSubquery(plan(ctx.query)) + } + + /** + * Create a value based [[CaseWhen]] expression. 
This has the following SQL form: + * {{{ + * CASE [expression] + * WHEN [value] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + */ + override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { + val e = expression(ctx.value) + val branches = ctx.whenClause.asScala.map { wCtx => + (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) + } + CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax: + * {{{ + * CASE + * WHEN [predicate] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + * + * @param ctx the parse tree + * */ + override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) { + val branches = ctx.whenClause.asScala.map { wCtx => + (expression(wCtx.condition), expression(wCtx.result)) + } + CaseWhen(branches.toSeq, Option(ctx.elseExpression).map(expression)) + } + + /** + * Currently only regex in expressions of SELECT statements are supported; in other + * places, e.g., where `(a)?+.+` = 2, regex are not meaningful. + */ + private def canApplyRegex(ctx: ParserRuleContext): Boolean = withOrigin(ctx) { + var parent = ctx.getParent + var rtn = false + while (parent != null) { + if (parent.isInstanceOf[NamedExpressionContext]) { + rtn = true + } + parent = parent.getParent + } + rtn + } + + /** + * Create a dereference expression. The return type depends on the type of the parent. + * If the parent is an [[UnresolvedAttribute]], it can be a [[UnresolvedAttribute]] or + * a [[UnresolvedRegex]] for regex quoted in ``; if the parent is some other expression, + * it can be [[UnresolvedExtractValue]]. + */ + override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) { + val attr = ctx.fieldName.getText + expression(ctx.base) match { + case unresolved_attr @ UnresolvedAttribute(nameParts) => + ctx.fieldName.getStart.getText match { + case escapedIdentifier(columnNameRegex) + if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) => + UnresolvedRegex(columnNameRegex, Some(unresolved_attr.name), + conf.caseSensitiveAnalysis) + case _ => + UnresolvedAttribute(nameParts :+ attr) + } + case e => + UnresolvedExtractValue(e, Literal(attr)) + } + } + + /** + * Create an [[UnresolvedAttribute]] expression or a [[UnresolvedRegex]] if it is a regex + * quoted in `` + */ + override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) { + ctx.getStart.getText match { + case escapedIdentifier(columnNameRegex) + if conf.supportQuotedRegexColumnName && canApplyRegex(ctx) => + UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis) + case _ => + UnresolvedAttribute.quoted(ctx.getText) + } + + } + + /** + * Create an [[UnresolvedExtractValue]] expression, this is used for subscript access to an array. + */ + override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) { + UnresolvedExtractValue(expression(ctx.value), expression(ctx.index)) + } + + /** + * Create an expression for an expression between parentheses. This is need because the ANTLR + * visitor cannot automatically convert the nested context into an expression. + */ + override def visitParenthesizedExpression( + ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) { + expression(ctx.expression) + } + + /** + * Create a [[SortOrder]] expression. 
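+   * For example, `a DESC NULLS FIRST` yields SortOrder(a, Descending, NullsFirst, Seq.empty);
+   * when no null ordering is given, the direction's default null ordering is used.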
+ */ + override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) { + val direction = if (ctx.DESC != null) { + Descending + } else { + Ascending + } + val nullOrdering = if (ctx.FIRST != null) { + NullsFirst + } else if (ctx.LAST != null) { + NullsLast + } else { + direction.defaultNullOrdering + } + SortOrder(expression(ctx.expression), direction, nullOrdering, Seq.empty) + } + + /** + * Create a typed Literal expression. A typed literal has the following SQL syntax: + * {{{ + * [TYPE] '[VALUE]' + * }}} + * Currently Date, Timestamp, Interval and Binary typed literals are supported. + */ + override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { + val value = string(ctx.STRING) + val valueType = ctx.identifier.getText.toUpperCase(Locale.ROOT) + + def toLiteral[T](f: UTF8String => Option[T], t: DataType): Literal = { + f(UTF8String.fromString(value)).map(Literal(_, t)).getOrElse { + throw cannotParseValueTypeError(valueType, value, ctx) + } + } + + def constructTimestampLTZLiteral(value: String): Literal = { + val zoneId = getZoneId(conf.sessionLocalTimeZone) + val specialTs = convertSpecialTimestamp(value, zoneId).map(Literal(_, TimestampType)) + specialTs.getOrElse(toLiteral(stringToTimestamp(_, zoneId), TimestampType)) + } + + try { + valueType match { + case "DATE" => + val zoneId = getZoneId(conf.sessionLocalTimeZone) + val specialDate = convertSpecialDate(value, zoneId).map(Literal(_, DateType)) + specialDate.getOrElse(toLiteral(stringToDate, DateType)) + // SPARK-36227: Remove TimestampNTZ type support in Spark 3.2 with minimal code changes. + case "TIMESTAMP_NTZ" if isTesting => + convertSpecialTimestampNTZ(value, getZoneId(conf.sessionLocalTimeZone)) + .map(Literal(_, TimestampNTZType)) + .getOrElse(toLiteral(stringToTimestampWithoutTimeZone, TimestampNTZType)) + case "TIMESTAMP_LTZ" if isTesting => + constructTimestampLTZLiteral(value) + case "TIMESTAMP" => + SQLConf.get.timestampType match { + case TimestampNTZType => + convertSpecialTimestampNTZ(value, getZoneId(conf.sessionLocalTimeZone)) + .map(Literal(_, TimestampNTZType)) + .getOrElse { + val containsTimeZonePart = + DateTimeUtils.parseTimestampString(UTF8String.fromString(value))._2.isDefined + // If the input string contains time zone part, return a timestamp with local time + // zone literal. + if (containsTimeZonePart) { + constructTimestampLTZLiteral(value) + } else { + toLiteral(stringToTimestampWithoutTimeZone, TimestampNTZType) + } + } + + case TimestampType => + constructTimestampLTZLiteral(value) + } + + case "INTERVAL" => + val interval = try { + IntervalUtils.stringToInterval(UTF8String.fromString(value)) + } catch { + case e: IllegalArgumentException => + val ex = cannotParseIntervalValueError(value, ctx) + ex.setStackTrace(e.getStackTrace) + throw ex + } + if (!conf.legacyIntervalEnabled) { + val units = value + .split("\\s") + .map(_.toLowerCase(Locale.ROOT).stripSuffix("s")) + .filter(s => s != "interval" && s.matches("[a-z]+")) + constructMultiUnitsIntervalLiteral(ctx, interval, units) + } else { + Literal(interval, CalendarIntervalType) + } + case "X" => + val padding = if (value.length % 2 != 0) "0" else "" + Literal(DatatypeConverter.parseHexBinary(padding + value)) + case other => + throw literalValueTypeUnsupportedError(other, ctx) + } + } catch { + case e: IllegalArgumentException => + throw parsingValueTypeError(e, valueType, ctx) + } + } + + /** + * Create a NULL literal expression. 
+ */ + override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) { + Literal(null) + } + + /** + * Create a Boolean literal expression. + */ + override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) { + if (ctx.getText.toBoolean) { + Literal.TrueLiteral + } else { + Literal.FalseLiteral + } + } + + /** + * Create an integral literal expression. The code selects the most narrow integral type + * possible, either a BigDecimal, a Long or an Integer is returned. + */ + override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) { + BigDecimal(ctx.getText) match { + case v if v.isValidInt => + Literal(v.intValue) + case v if v.isValidLong => + Literal(v.longValue) + case v => Literal(v.underlying()) + } + } + + /** + * Create a decimal literal for a regular decimal number. + */ + override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(BigDecimal(ctx.getText).underlying()) + } + + /** + * Create a decimal literal for a regular decimal number or a scientific decimal number. + */ + override def visitLegacyDecimalLiteral( + ctx: LegacyDecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(BigDecimal(ctx.getText).underlying()) + } + + /** + * Create a double literal for number with an exponent, e.g. 1E-30 + */ + override def visitExponentLiteral(ctx: ExponentLiteralContext): Literal = { + numericLiteral(ctx, ctx.getText, /* exponent values don't have a suffix */ + Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble) + } + + /** Create a numeric literal expression. */ + private def numericLiteral( + ctx: NumberContext, + rawStrippedQualifier: String, + minValue: BigDecimal, + maxValue: BigDecimal, + typeName: String)(converter: String => Any): Literal = withOrigin(ctx) { + try { + val rawBigDecimal = BigDecimal(rawStrippedQualifier) + if (rawBigDecimal < minValue || rawBigDecimal > maxValue) { + throw invalidNumericLiteralRangeError( + rawStrippedQualifier, minValue, maxValue, typeName, ctx) + } + Literal(converter(rawStrippedQualifier)) + } catch { + case e: NumberFormatException => + throw new ParseException(e.getMessage, ctx) + } + } + + /** + * Create a Byte Literal expression. + */ + override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Byte.MinValue, Byte.MaxValue, ByteType.simpleString)(_.toByte) + } + + /** + * Create a Short Literal expression. + */ + override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Short.MinValue, Short.MaxValue, ShortType.simpleString)(_.toShort) + } + + /** + * Create a Long Literal expression. + */ + override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Long.MinValue, Long.MaxValue, LongType.simpleString)(_.toLong) + } + + /** + * Create a Float Literal expression. + */ + override def visitFloatLiteral(ctx: FloatLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Float.MinValue, Float.MaxValue, FloatType.simpleString)(_.toFloat) + } + + /** + * Create a Double Literal expression. 
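+ * For example (illustrative), `1.5D`; the trailing `D` suffix is stripped before the value is + * range-checked and converted.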
+ */ + override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = { + val rawStrippedQualifier = ctx.getText.substring(0, ctx.getText.length - 1) + numericLiteral(ctx, rawStrippedQualifier, + Double.MinValue, Double.MaxValue, DoubleType.simpleString)(_.toDouble) + } + + /** + * Create a BigDecimal Literal expression. + */ + override def visitBigDecimalLiteral(ctx: BigDecimalLiteralContext): Literal = { + val raw = ctx.getText.substring(0, ctx.getText.length - 2) + try { + Literal(BigDecimal(raw).underlying()) + } catch { + case e: AnalysisException => + throw new ParseException(e.message, ctx) + } + } + + /** + * Create a String literal expression. + */ + override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { + Literal(createString(ctx)) + } + + /** + * Create a String from a string literal context. This supports multiple consecutive string + * literals, these are concatenated, for example this expression "'hello' 'world'" will be + * converted into "helloworld". + * + * Special characters can be escaped by using Hive/C-style escaping. + */ + private def createString(ctx: StringLiteralContext): String = { + if (conf.escapedStringLiterals) { + ctx.STRING().asScala.map(stringWithoutUnescape).mkString + } else { + ctx.STRING().asScala.map(string).mkString + } + } + + /** + * Create an [[UnresolvedRelation]] from a multi-part identifier context. + */ + private def createUnresolvedRelation( + ctx: MultipartIdentifierContext): UnresolvedRelation = withOrigin(ctx) { + UnresolvedRelation(visitMultipartIdentifier(ctx)) + } + + /** + * Create an [[UnresolvedTable]] from a multi-part identifier context. + */ + private def createUnresolvedTable( + ctx: MultipartIdentifierContext, + commandName: String, + relationTypeMismatchHint: Option[String] = None): UnresolvedTable = withOrigin(ctx) { + UnresolvedTable(visitMultipartIdentifier(ctx), commandName, relationTypeMismatchHint) + } + + /** + * Create an [[UnresolvedView]] from a multi-part identifier context. + */ + private def createUnresolvedView( + ctx: MultipartIdentifierContext, + commandName: String, + allowTemp: Boolean = true, + relationTypeMismatchHint: Option[String] = None): UnresolvedView = withOrigin(ctx) { + UnresolvedView(visitMultipartIdentifier(ctx), commandName, allowTemp, relationTypeMismatchHint) + } + + /** + * Create an [[UnresolvedTableOrView]] from a multi-part identifier context. + */ + private def createUnresolvedTableOrView( + ctx: MultipartIdentifierContext, + commandName: String, + allowTempView: Boolean = true): UnresolvedTableOrView = withOrigin(ctx) { + UnresolvedTableOrView(visitMultipartIdentifier(ctx), commandName, allowTempView) + } + + /** + * Construct an [[Literal]] from [[CalendarInterval]] and + * units represented as a [[Seq]] of [[String]]. 
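+ * For example (illustrative), `INTERVAL 1 YEAR 2 MONTHS` carries the units `year` and `month`, + * so a year-month interval literal is produced.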
+ */ + private def constructMultiUnitsIntervalLiteral( + ctx: ParserRuleContext, + calendarInterval: CalendarInterval, + units: Seq[String]): Literal = { + var yearMonthFields = Set.empty[Byte] + var dayTimeFields = Set.empty[Byte] + for (unit <- units) { + if (YearMonthIntervalType.stringToField.contains(unit)) { + yearMonthFields += YearMonthIntervalType.stringToField(unit) + } else if (DayTimeIntervalType.stringToField.contains(unit)) { + dayTimeFields += DayTimeIntervalType.stringToField(unit) + } else if (unit == "week") { + dayTimeFields += DayTimeIntervalType.DAY + } else { + assert(unit == "millisecond" || unit == "microsecond") + dayTimeFields += DayTimeIntervalType.SECOND + } + } + if (yearMonthFields.nonEmpty) { + if (dayTimeFields.nonEmpty) { + val literalStr = source(ctx) + throw mixedIntervalUnitsError(literalStr, ctx) + } + Literal( + calendarInterval.months, + YearMonthIntervalType(yearMonthFields.min, yearMonthFields.max) + ) + } else { + Literal( + IntervalUtils.getDuration(calendarInterval, TimeUnit.MICROSECONDS), + DayTimeIntervalType(dayTimeFields.min, dayTimeFields.max)) + } + } + + /** + * Create a [[CalendarInterval]] or ANSI interval literal expression. + * Two syntaxes are supported: + * - multiple unit value pairs, for instance: interval 2 months 2 days. + * - from-to unit, for instance: interval '1-2' year to month. + */ + override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) { + val calendarInterval = parseIntervalLiteral(ctx) + if (ctx.errorCapturingUnitToUnitInterval != null && !conf.legacyIntervalEnabled) { + // Check the `to` unit to distinguish year-month and day-time intervals because + // `CalendarInterval` doesn't have enough info. For instance, new CalendarInterval(0, 0, 0) + // can be derived from INTERVAL '0-0' YEAR TO MONTH as well as from + // INTERVAL '0 00:00:00' DAY TO SECOND. 
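+ // For example (illustrative), INTERVAL '1-2' YEAR TO MONTH populates only the months field, + // while INTERVAL '1 10:00:00' DAY TO SECOND leaves months at zero.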
+ val fromUnit = + ctx.errorCapturingUnitToUnitInterval.body.from.getText.toLowerCase(Locale.ROOT) + val toUnit = ctx.errorCapturingUnitToUnitInterval.body.to.getText.toLowerCase(Locale.ROOT) + if (toUnit == "month") { + assert(calendarInterval.days == 0 && calendarInterval.microseconds == 0) + val start = YearMonthIntervalType.stringToField(fromUnit) + Literal(calendarInterval.months, YearMonthIntervalType(start, YearMonthIntervalType.MONTH)) + } else { + assert(calendarInterval.months == 0) + val micros = IntervalUtils.getDuration(calendarInterval, TimeUnit.MICROSECONDS) + val start = DayTimeIntervalType.stringToField(fromUnit) + val end = DayTimeIntervalType.stringToField(toUnit) + Literal(micros, DayTimeIntervalType(start, end)) + } + } else if (ctx.errorCapturingMultiUnitsInterval != null && !conf.legacyIntervalEnabled) { + val units = + ctx.errorCapturingMultiUnitsInterval.body.unit.asScala.map( + _.getText.toLowerCase(Locale.ROOT).stripSuffix("s")).toSeq + constructMultiUnitsIntervalLiteral(ctx, calendarInterval, units) + } else { + Literal(calendarInterval, CalendarIntervalType) + } + } + + /** + * Create a [[CalendarInterval]] object + */ + protected def parseIntervalLiteral(ctx: IntervalContext): CalendarInterval = withOrigin(ctx) { + if (ctx.errorCapturingMultiUnitsInterval != null) { + val innerCtx = ctx.errorCapturingMultiUnitsInterval + if (innerCtx.unitToUnitInterval != null) { + throw moreThanOneFromToUnitInIntervalLiteralError( + innerCtx.unitToUnitInterval) + } + visitMultiUnitsInterval(innerCtx.multiUnitsInterval) + } else if (ctx.errorCapturingUnitToUnitInterval != null) { + val innerCtx = ctx.errorCapturingUnitToUnitInterval + if (innerCtx.error1 != null || innerCtx.error2 != null) { + val errorCtx = if (innerCtx.error1 != null) innerCtx.error1 else innerCtx.error2 + throw moreThanOneFromToUnitInIntervalLiteralError(errorCtx) + } + visitUnitToUnitInterval(innerCtx.body) + } else { + throw invalidIntervalLiteralError(ctx) + } + } + + /** + * Creates a [[CalendarInterval]] with multiple unit value pairs, e.g. 1 YEAR 2 DAYS. + */ + override def visitMultiUnitsInterval(ctx: MultiUnitsIntervalContext): CalendarInterval = { + withOrigin(ctx) { + val units = ctx.unit.asScala + val values = ctx.intervalValue().asScala + try { + assert(units.length == values.length) + val kvs = units.indices.map { i => + val u = units(i).getText + val v = if (values(i).STRING() != null) { + val value = string(values(i).STRING()) + // SPARK-32840: For invalid cases, e.g. INTERVAL '1 day 2' hour, + // INTERVAL 'interval 1' day, we need to check ahead before they are concatenated with + // units and become valid ones, e.g. '1 day 2 hour'. + // Ideally, we only ensure the value parts don't contain any units here. + if (value.exists(Character.isLetter)) { + throw invalidIntervalFormError(value, ctx) + } + if (values(i).MINUS() == null) { + value + } else { + value.startsWith("-") match { + case true => value.replaceFirst("-", "") + case false => s"-$value" + } + } + } else { + values(i).getText + } + UTF8String.fromString(" " + v + " " + u) + } + IntervalUtils.stringToInterval(UTF8String.concat(kvs: _*)) + } catch { + case i: IllegalArgumentException => + val e = new ParseException(i.getMessage, ctx) + e.setStackTrace(i.getStackTrace) + throw e + } + } + } + + /** + * Creates a [[CalendarInterval]] with from-to unit, e.g. '2-1' YEAR TO MONTH. 
+ */ + override def visitUnitToUnitInterval(ctx: UnitToUnitIntervalContext): CalendarInterval = { + withOrigin(ctx) { + val value = Option(ctx.intervalValue.STRING).map(string).map { interval => + if (ctx.intervalValue().MINUS() == null) { + interval + } else { + interval.startsWith("-") match { + case true => interval.replaceFirst("-", "") + case false => s"-$interval" + } + } + }.getOrElse { + throw invalidFromToUnitValueError(ctx.intervalValue) + } + try { + val from = ctx.from.getText.toLowerCase(Locale.ROOT) + val to = ctx.to.getText.toLowerCase(Locale.ROOT) + (from, to) match { + case ("year", "month") => + IntervalUtils.fromYearMonthString(value) + case ("day", "hour") | ("day", "minute") | ("day", "second") | ("hour", "minute") | + ("hour", "second") | ("minute", "second") => + IntervalUtils.fromDayTimeString(value, + DayTimeIntervalType.stringToField(from), DayTimeIntervalType.stringToField(to)) + case _ => + throw fromToIntervalUnsupportedError(from, to, ctx) + } + } catch { + // Handle Exceptions thrown by CalendarInterval + case e: IllegalArgumentException => + val pe = new ParseException(e.getMessage, ctx) + pe.setStackTrace(e.getStackTrace) + throw pe + } + } + } + + /* ******************************************************************************************** + * DataType parsing + * ******************************************************************************************** */ + + /** + * Resolve/create a primitive type. + */ + override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { + val dataType = ctx.identifier.getText.toLowerCase(Locale.ROOT) + (dataType, ctx.INTEGER_VALUE().asScala.toList) match { + case ("boolean", Nil) => BooleanType + case ("tinyint" | "byte", Nil) => ByteType + case ("smallint" | "short", Nil) => ShortType + case ("int" | "integer", Nil) => IntegerType + case ("bigint" | "long", Nil) => LongType + case ("float" | "real", Nil) => FloatType + case ("double", Nil) => DoubleType + case ("date", Nil) => DateType + case ("timestamp", Nil) => SQLConf.get.timestampType + // SPARK-36227: Remove TimestampNTZ type support in Spark 3.2 with minimal code changes. 
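+ // The explicit timestamp_ntz / timestamp_ltz names below are only recognized while `isTesting` + // is set; otherwise they fall through to the unsupported-type error.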
+ case ("timestamp_ntz", Nil) if isTesting => TimestampNTZType + case ("timestamp_ltz", Nil) if isTesting => TimestampType + case ("string", Nil) => StringType + case ("character" | "char", length :: Nil) => CharType(length.getText.toInt) + case ("varchar", length :: Nil) => VarcharType(length.getText.toInt) + case ("binary", Nil) => BinaryType + case ("decimal" | "dec" | "numeric", Nil) => DecimalType.USER_DEFAULT + case ("decimal" | "dec" | "numeric", precision :: Nil) => + DecimalType(precision.getText.toInt, 0) + case ("decimal" | "dec" | "numeric", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case ("void", Nil) => NullType + case ("interval", Nil) => CalendarIntervalType + case (dt, params) => + val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt + throw dataTypeUnsupportedError(dtStr, ctx) + } + } + + override def visitYearMonthIntervalDataType(ctx: YearMonthIntervalDataTypeContext): DataType = { + val startStr = ctx.from.getText.toLowerCase(Locale.ROOT) + val start = YearMonthIntervalType.stringToField(startStr) + if (ctx.to != null) { + val endStr = ctx.to.getText.toLowerCase(Locale.ROOT) + val end = YearMonthIntervalType.stringToField(endStr) + if (end <= start) { + throw fromToIntervalUnsupportedError(startStr, endStr, ctx) + } + YearMonthIntervalType(start, end) + } else { + YearMonthIntervalType(start) + } + } + + override def visitDayTimeIntervalDataType(ctx: DayTimeIntervalDataTypeContext): DataType = { + val startStr = ctx.from.getText.toLowerCase(Locale.ROOT) + val start = DayTimeIntervalType.stringToField(startStr) + if (ctx.to != null ) { + val endStr = ctx.to.getText.toLowerCase(Locale.ROOT) + val end = DayTimeIntervalType.stringToField(endStr) + if (end <= start) { + throw fromToIntervalUnsupportedError(startStr, endStr, ctx) + } + DayTimeIntervalType(start, end) + } else { + DayTimeIntervalType(start) + } + } + + /** + * Create a complex DataType. Arrays, Maps and Structures are supported. + */ + override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { + ctx.complex.getType match { + case HoodieSqlBaseParser.ARRAY => + ArrayType(typedVisit(ctx.dataType(0))) + case HoodieSqlBaseParser.MAP => + MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) + case HoodieSqlBaseParser.STRUCT => + StructType(Option(ctx.complexColTypeList).toSeq.flatMap(visitComplexColTypeList)) + } + } + + /** + * Create top level table schema. + */ + protected def createSchema(ctx: ColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.colType().asScala.map(visitColType).toSeq + } + + /** + * Create a top level [[StructField]] from a column definition. + */ + override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + + val builder = new MetadataBuilder + // Add comment to metadata + Option(commentSpec()).map(visitCommentSpec).foreach { + builder.putString("comment", _) + } + + StructField( + name = colName.getText, + dataType = typedVisit[DataType](ctx.dataType), + nullable = NULL == null, + metadata = builder.build()) + } + + /** + * Create a [[StructType]] from a sequence of [[StructField]]s. 
+ */ + protected def createStructType(ctx: ComplexColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitComplexColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitComplexColTypeList( + ctx: ComplexColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.complexColType().asScala.map(visitComplexColType).toSeq + } + + /** + * Create a [[StructField]] from a column definition. + */ + override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + val structField = StructField( + name = identifier.getText, + dataType = typedVisit(dataType()), + nullable = NULL == null) + Option(commentSpec).map(visitCommentSpec).map(structField.withComment).getOrElse(structField) + } + + /** + * Create a location string. + */ + override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) { + string(ctx.STRING) + } + + /** + * Create an optional location string. + */ + protected def visitLocationSpecList(ctx: java.util.List[LocationSpecContext]): Option[String] = { + ctx.asScala.headOption.map(visitLocationSpec) + } + + /** + * Create a comment string. + */ + override def visitCommentSpec(ctx: CommentSpecContext): String = withOrigin(ctx) { + string(ctx.STRING) + } + + /** + * Create an optional comment string. + */ + protected def visitCommentSpecList(ctx: java.util.List[CommentSpecContext]): Option[String] = { + ctx.asScala.headOption.map(visitCommentSpec) + } + + /** + * Create a [[BucketSpec]]. + */ + override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) { + BucketSpec( + ctx.INTEGER_VALUE.getText.toInt, + visitIdentifierList(ctx.identifierList), + Option(ctx.orderedIdentifierList) + .toSeq + .flatMap(_.orderedIdentifier.asScala) + .map { orderedIdCtx => + Option(orderedIdCtx.ordering).map(_.getText).foreach { dir => + if (dir.toLowerCase(Locale.ROOT) != "asc") { + operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx) + } + } + + orderedIdCtx.ident.getText + }) + } + + /** + * Convert a table property list into a key-value map. + * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. + */ + override def visitTablePropertyList( + ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { + val properties = ctx.tableProperty.asScala.map { property => + val key = visitTablePropertyKey(property.key) + val value = visitTablePropertyValue(property.value) + key -> value + } + // Check for duplicate property names. + checkDuplicateKeys(properties.toSeq, ctx) + properties.toMap + } + + /** + * Parse a key-value map from a [[TablePropertyListContext]], assuming all values are specified. + */ + def visitPropertyKeyValues(ctx: TablePropertyListContext): Map[String, String] = { + val props = visitTablePropertyList(ctx) + val badKeys = props.collect { case (key, null) => key } + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props + } + + /** + * Parse a list of keys from a [[TablePropertyListContext]], assuming no values are specified. 
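+ * For example (illustrative), a property list such as `('key1', 'key2')` yields + * `Seq("key1", "key2")`.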
+ */ + def visitPropertyKeys(ctx: TablePropertyListContext): Seq[String] = { + val props = visitTablePropertyList(ctx) + val badKeys = props.filter { case (_, v) => v != null }.keys + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props.keys.toSeq + } + + /** + * A table property key can either be String or a collection of dot separated elements. This + * function extracts the property key based on whether its a string literal or a table property + * identifier. + */ + override def visitTablePropertyKey(key: TablePropertyKeyContext): String = { + if (key.STRING != null) { + string(key.STRING) + } else { + key.getText + } + } + + /** + * A table property value can be String, Integer, Boolean or Decimal. This function extracts + * the property value based on whether its a string, integer, boolean or decimal literal. + */ + override def visitTablePropertyValue(value: TablePropertyValueContext): String = { + if (value == null) { + null + } else if (value.STRING != null) { + string(value.STRING) + } else if (value.booleanValue != null) { + value.getText.toLowerCase(Locale.ROOT) + } else { + value.getText + } + } + + /** + * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). + */ + type TableHeader = (Seq[String], Boolean, Boolean, Boolean) + + /** + * Type to keep track of table clauses: + * - partition transforms + * - partition columns + * - bucketSpec + * - properties + * - options + * - location + * - comment + * - serde + * + * Note: Partition transforms are based on existing table schema definition. It can be simple + * column names, or functions like `year(date_col)`. Partition columns are column names with data + * types like `i INT`, which should be appended to the existing table schema. + */ + type TableClauses = ( + Seq[Transform], Seq[StructField], Option[BucketSpec], Map[String, String], + Map[String, String], Option[String], Option[String], Option[SerdeInfo]) + + /** + * Validate a create table statement and return the [[TableIdentifier]]. + */ + override def visitCreateTableHeader( + ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { + val temporary = ctx.TEMPORARY != null + val ifNotExists = ctx.EXISTS != null + if (temporary && ifNotExists) { + operationNotAllowed("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx) + } + val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq + (multipartIdentifier, temporary, ifNotExists, ctx.EXTERNAL != null) + } + + /** + * Validate a replace table statement and return the [[TableIdentifier]]. + */ + override def visitReplaceTableHeader( + ctx: ReplaceTableHeaderContext): TableHeader = withOrigin(ctx) { + val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText).toSeq + (multipartIdentifier, false, false, false) + } + + /** + * Parse a qualified name to a multipart name. + */ + override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) { + ctx.identifier.asScala.map(_.getText).toSeq + } + + /** + * Parse a list of transforms or columns. 
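+ * For example (illustrative), `PARTITIONED BY (years(ts), bucket(8, id))` yields two transforms, + * while `PARTITIONED BY (dt STRING)` yields a single partition column.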
+ */ + override def visitPartitionFieldList( + ctx: PartitionFieldListContext): (Seq[Transform], Seq[StructField]) = withOrigin(ctx) { + val (transforms, columns) = ctx.fields.asScala.map { + case transform: PartitionTransformContext => + (Some(visitPartitionTransform(transform)), None) + case field: PartitionColumnContext => + (None, Some(visitColType(field.colType))) + }.unzip + + (transforms.flatten.toSeq, columns.flatten.toSeq) + } + + override def visitPartitionTransform( + ctx: PartitionTransformContext): Transform = withOrigin(ctx) { + def getFieldReference( + ctx: ApplyTransformContext, + arg: V2Expression): FieldReference = { + lazy val name: String = ctx.identifier.getText + arg match { + case ref: FieldReference => + ref + case nonRef => + throw partitionTransformNotExpectedError(name, nonRef.describe, ctx) + } + } + + def getSingleFieldReference( + ctx: ApplyTransformContext, + arguments: Seq[V2Expression]): FieldReference = { + lazy val name: String = ctx.identifier.getText + if (arguments.size > 1) { + throw tooManyArgumentsForTransformError(name, ctx) + } else if (arguments.isEmpty) { + throw notEnoughArgumentsForTransformError(name, ctx) + } else { + getFieldReference(ctx, arguments.head) + } + } + + ctx.transform match { + case identityCtx: IdentityTransformContext => + IdentityTransform(FieldReference(typedVisit[Seq[String]](identityCtx.qualifiedName))) + + case applyCtx: ApplyTransformContext => + val arguments = applyCtx.argument.asScala.map(visitTransformArgument).toSeq + + applyCtx.identifier.getText match { + case "bucket" => + val numBuckets: Int = arguments.head match { + case LiteralValue(shortValue, ShortType) => + shortValue.asInstanceOf[Short].toInt + case LiteralValue(intValue, IntegerType) => + intValue.asInstanceOf[Int] + case LiteralValue(longValue, LongType) => + longValue.asInstanceOf[Long].toInt + case lit => + throw invalidBucketsNumberError(lit.describe, applyCtx) + } + + val fields = arguments.tail.map(arg => getFieldReference(applyCtx, arg)) + + BucketTransform(LiteralValue(numBuckets, IntegerType), fields) + + case "years" => + YearsTransform(getSingleFieldReference(applyCtx, arguments)) + + case "months" => + MonthsTransform(getSingleFieldReference(applyCtx, arguments)) + + case "days" => + DaysTransform(getSingleFieldReference(applyCtx, arguments)) + + case "hours" => + HoursTransform(getSingleFieldReference(applyCtx, arguments)) + + case name => + ApplyTransform(name, arguments) + } + } + } + + /** + * Parse an argument to a transform. An argument may be a field reference (qualified name) or + * a value literal. 
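+ * For example (illustrative), in `bucket(16, id)` the first argument is the literal `16` and the + * second is a field reference to `id`.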
+ */ + override def visitTransformArgument(ctx: TransformArgumentContext): V2Expression = { + withOrigin(ctx) { + val reference = Option(ctx.qualifiedName) + .map(typedVisit[Seq[String]]) + .map(FieldReference(_)) + val literal = Option(ctx.constant) + .map(typedVisit[Literal]) + .map(lit => LiteralValue(lit.value, lit.dataType)) + reference.orElse(literal) + .getOrElse(throw invalidTransformArgumentError(ctx)) + } + } + + private def cleanNamespaceProperties( + properties: Map[String, String], + ctx: ParserRuleContext): Map[String, String] = withOrigin(ctx) { + import SupportsNamespaces._ + val legacyOn = conf.getConf(SQLConf.LEGACY_PROPERTY_NON_RESERVED) + properties.filter { + case (PROP_LOCATION, _) if !legacyOn => + throw cannotCleanReservedNamespacePropertyError( + PROP_LOCATION, ctx, "please use the LOCATION clause to specify it") + case (PROP_LOCATION, _) => false + case (PROP_OWNER, _) if !legacyOn => + throw cannotCleanReservedNamespacePropertyError( + PROP_OWNER, ctx, "it will be set to the current user") + case (PROP_OWNER, _) => false + case _ => true + } + } + + def cleanTableProperties( + ctx: ParserRuleContext, properties: Map[String, String]): Map[String, String] = { + import TableCatalog._ + val legacyOn = conf.getConf(SQLConf.LEGACY_PROPERTY_NON_RESERVED) + properties.filter { + case (PROP_PROVIDER, _) if !legacyOn => + throw cannotCleanReservedTablePropertyError( + PROP_PROVIDER, ctx, "please use the USING clause to specify it") + case (PROP_PROVIDER, _) => false + case (PROP_LOCATION, _) if !legacyOn => + throw cannotCleanReservedTablePropertyError( + PROP_LOCATION, ctx, "please use the LOCATION clause to specify it") + case (PROP_LOCATION, _) => false + case (PROP_OWNER, _) if !legacyOn => + throw cannotCleanReservedTablePropertyError( + PROP_OWNER, ctx, "it will be set to the current user") + case (PROP_OWNER, _) => false + case _ => true + } + } + + def cleanTableOptions( + ctx: ParserRuleContext, + options: Map[String, String], + location: Option[String]): (Map[String, String], Option[String]) = { + var path = location + val filtered = cleanTableProperties(ctx, options).filter { + case (k, v) if k.equalsIgnoreCase("path") && path.nonEmpty => + throw duplicatedTablePathsFoundError(path.get, v, ctx) + case (k, v) if k.equalsIgnoreCase("path") => + path = Some(v) + false + case _ => true + } + (filtered, path) + } + + /** + * Create a [[SerdeInfo]] for creating tables. + * + * Format: STORED AS (name | INPUTFORMAT input_format OUTPUTFORMAT output_format) + */ + override def visitCreateFileFormat(ctx: CreateFileFormatContext): SerdeInfo = withOrigin(ctx) { + (ctx.fileFormat, ctx.storageHandler) match { + // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format + case (c: TableFileFormatContext, null) => + SerdeInfo(formatClasses = Some(FormatClasses(string(c.inFmt), string(c.outFmt)))) + // Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO + case (c: GenericFileFormatContext, null) => + SerdeInfo(storedAs = Some(c.identifier.getText)) + case (null, storageHandler) => + operationNotAllowed("STORED BY", ctx) + case _ => + throw storedAsAndStoredByBothSpecifiedError(ctx) + } + } + + /** + * Create a [[SerdeInfo]] used for creating tables. 
+ * + * Example format: + * {{{ + * SERDE serde_name [WITH SERDEPROPERTIES (k1=v1, k2=v2, ...)] + * }}} + * + * OR + * + * {{{ + * DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]] + * [COLLECTION ITEMS TERMINATED BY char] + * [MAP KEYS TERMINATED BY char] + * [LINES TERMINATED BY char] + * [NULL DEFINED AS char] + * }}} + */ + def visitRowFormat(ctx: RowFormatContext): SerdeInfo = withOrigin(ctx) { + ctx match { + case serde: RowFormatSerdeContext => visitRowFormatSerde(serde) + case delimited: RowFormatDelimitedContext => visitRowFormatDelimited(delimited) + } + } + + /** + * Create SERDE row format name and properties pair. + */ + override def visitRowFormatSerde(ctx: RowFormatSerdeContext): SerdeInfo = withOrigin(ctx) { + import ctx._ + SerdeInfo( + serde = Some(string(name)), + serdeProperties = Option(tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty)) + } + + /** + * Create a delimited row format properties object. + */ + override def visitRowFormatDelimited( + ctx: RowFormatDelimitedContext): SerdeInfo = withOrigin(ctx) { + // Collect the entries if any. + def entry(key: String, value: Token): Seq[(String, String)] = { + Option(value).toSeq.map(x => key -> string(x)) + } + // TODO we need proper support for the NULL format. + val entries = + entry("field.delim", ctx.fieldsTerminatedBy) ++ + entry("serialization.format", ctx.fieldsTerminatedBy) ++ + entry("escape.delim", ctx.escapedBy) ++ + // The following typo is inherited from Hive... + entry("colelction.delim", ctx.collectionItemsTerminatedBy) ++ + entry("mapkey.delim", ctx.keysTerminatedBy) ++ + Option(ctx.linesSeparatedBy).toSeq.map { token => + val value = string(token) + validate( + value == "\n", + s"LINES TERMINATED BY only supports newline '\\n' right now: $value", + ctx) + "line.delim" -> value + } + SerdeInfo(serdeProperties = entries.toMap) + } + + /** + * Throw a [[ParseException]] if the user specified incompatible SerDes through ROW FORMAT + * and STORED AS. + * + * The following are allowed. Anything else is not: + * ROW FORMAT SERDE ... STORED AS [SEQUENCEFILE | RCFILE | TEXTFILE] + * ROW FORMAT DELIMITED ... STORED AS TEXTFILE + * ROW FORMAT ... STORED AS INPUTFORMAT ... OUTPUTFORMAT ... 
+ */ + protected def validateRowFormatFileFormat( + rowFormatCtx: RowFormatContext, + createFileFormatCtx: CreateFileFormatContext, + parentCtx: ParserRuleContext): Unit = { + if (!(rowFormatCtx == null || createFileFormatCtx == null)) { + (rowFormatCtx, createFileFormatCtx.fileFormat) match { + case (_, ffTable: TableFileFormatContext) => // OK + case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) => + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { + case ("sequencefile" | "textfile" | "rcfile") => // OK + case fmt => + operationNotAllowed( + s"ROW FORMAT SERDE is incompatible with format '$fmt', which also specifies a serde", + parentCtx) + } + case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) => + ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match { + case "textfile" => // OK + case fmt => operationNotAllowed( + s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx) + } + case _ => + // should never happen + def str(ctx: ParserRuleContext): String = { + (0 until ctx.getChildCount).map { i => ctx.getChild(i).getText }.mkString(" ") + } + + operationNotAllowed( + s"Unexpected combination of ${str(rowFormatCtx)} and ${str(createFileFormatCtx)}", + parentCtx) + } + } + } + + protected def validateRowFormatFileFormat( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) { + validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx) + } + } + + override def visitCreateTableClauses(ctx: CreateTableClausesContext): TableClauses = { + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx) + checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) + checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + + if (ctx.skewSpec.size > 0) { + operationNotAllowed("CREATE TABLE ... 
SKEWED BY", ctx) + } + + val (partTransforms, partCols) = + Option(ctx.partitioning).map(visitPartitionFieldList).getOrElse((Nil, Nil)) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) + val cleanedProperties = cleanTableProperties(ctx, properties) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) + val location = visitLocationSpecList(ctx.locationSpec()) + val (cleanedOptions, newLocation) = cleanTableOptions(ctx, options, location) + val comment = visitCommentSpecList(ctx.commentSpec()) + val serdeInfo = + getSerdeInfo(ctx.rowFormat.asScala.toSeq, ctx.createFileFormat.asScala.toSeq, ctx) + (partTransforms, partCols, bucketSpec, cleanedProperties, cleanedOptions, newLocation, comment, + serdeInfo) + } + + protected def getSerdeInfo( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + ctx: ParserRuleContext): Option[SerdeInfo] = { + validateRowFormatFileFormat(rowFormatCtx, createFileFormatCtx, ctx) + val rowFormatSerdeInfo = rowFormatCtx.map(visitRowFormat) + val fileFormatSerdeInfo = createFileFormatCtx.map(visitCreateFileFormat) + (fileFormatSerdeInfo ++ rowFormatSerdeInfo).reduceLeftOption((l, r) => l.merge(r)) + } + + private def partitionExpressions( + partTransforms: Seq[Transform], + partCols: Seq[StructField], + ctx: ParserRuleContext): Seq[Transform] = { + if (partTransforms.nonEmpty) { + if (partCols.nonEmpty) { + val references = partTransforms.map(_.describe()).mkString(", ") + val columns = partCols + .map(field => s"${field.name} ${field.dataType.simpleString}") + .mkString(", ") + operationNotAllowed( + s"""PARTITION BY: Cannot mix partition expressions and partition columns: + |Expressions: $references + |Columns: $columns""".stripMargin, ctx) + + } + partTransforms + } else { + // columns were added to create the schema. convert to column references + partCols.map { column => + IdentityTransform(FieldReference(Seq(column.name))) + } + } + } + + /** + * Create a table, returning a [[CreateTableStatement]] logical plan. + * + * Expected format: + * {{{ + * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name + * [USING table_provider] + * create_table_clauses + * [[AS] select_statement]; + * + * create_table_clauses (order insensitive): + * [PARTITIONED BY (partition_fields)] + * [OPTIONS table_property_list] + * [ROW FORMAT row_format] + * [STORED AS file_format] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] + * + * partition_fields: + * col_name, transform(col_name), transform(constant, col_name), ... | + * col_name data_type [NOT NULL] [COMMENT col_comment], ... + * }}} + */ + override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { + val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + + val columns = Option(ctx.colTypeList()).map(visitColTypeList).getOrElse(Nil) + val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) + val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) = + visitCreateTableClauses(ctx.createTableClauses()) + + if (provider.isDefined && serdeInfo.isDefined) { + operationNotAllowed(s"CREATE TABLE ... USING ... 
${serdeInfo.get.describe}", ctx) + } + + if (temp) { + val asSelect = if (ctx.query == null) "" else " AS ..." + operationNotAllowed( + s"CREATE TEMPORARY TABLE ...$asSelect, use CREATE TEMPORARY VIEW instead", ctx) + } + + val partitioning = partitionExpressions(partTransforms, partCols, ctx) + + Option(ctx.query).map(plan) match { + case Some(_) if columns.nonEmpty => + operationNotAllowed( + "Schema may not be specified in a Create Table As Select (CTAS) statement", + ctx) + + case Some(_) if partCols.nonEmpty => + // non-reference partition columns are not allowed because schema can't be specified + operationNotAllowed( + "Partition column types may not be specified in Create Table As Select (CTAS)", + ctx) + + case Some(query) => + CreateTableAsSelectStatement( + table, query, partitioning, bucketSpec, properties, provider, options, location, comment, + writeOptions = Map.empty, serdeInfo, external = external, ifNotExists = ifNotExists) + + case _ => + // Note: table schema includes both the table columns list and the partition columns + // with data type. + val schema = StructType(columns ++ partCols) + CreateTableStatement(table, schema, partitioning, bucketSpec, properties, provider, + options, location, comment, serdeInfo, external = external, ifNotExists = ifNotExists) + } + } + + /** + * Replace a table, returning a [[ReplaceTableStatement]] logical plan. + * + * Expected format: + * {{{ + * [CREATE OR] REPLACE TABLE [db_name.]table_name + * [USING table_provider] + * replace_table_clauses + * [[AS] select_statement]; + * + * replace_table_clauses (order insensitive): + * [OPTIONS table_property_list] + * [PARTITIONED BY (partition_fields)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] + * + * partition_fields: + * col_name, transform(col_name), transform(constant, col_name), ... | + * col_name data_type [NOT NULL] [COMMENT col_comment], ... + * }}} + */ + override def visitReplaceTable(ctx: ReplaceTableContext): LogicalPlan = withOrigin(ctx) { + val (table, temp, ifNotExists, external) = visitReplaceTableHeader(ctx.replaceTableHeader) + val orCreate = ctx.replaceTableHeader().CREATE() != null + + if (temp) { + val action = if (orCreate) "CREATE OR REPLACE" else "REPLACE" + operationNotAllowed(s"$action TEMPORARY TABLE ..., use $action TEMPORARY VIEW instead.", ctx) + } + + if (external) { + operationNotAllowed("REPLACE EXTERNAL TABLE ...", ctx) + } + + if (ifNotExists) { + operationNotAllowed("REPLACE ... IF NOT EXISTS, use CREATE IF NOT EXISTS instead", ctx) + } + + val (partTransforms, partCols, bucketSpec, properties, options, location, comment, serdeInfo) = + visitCreateTableClauses(ctx.createTableClauses()) + val columns = Option(ctx.colTypeList()).map(visitColTypeList).getOrElse(Nil) + val provider = Option(ctx.tableProvider).map(_.multipartIdentifier.getText) + + if (provider.isDefined && serdeInfo.isDefined) { + operationNotAllowed(s"CREATE TABLE ... USING ... 
${serdeInfo.get.describe}", ctx) + } + + val partitioning = partitionExpressions(partTransforms, partCols, ctx) + + Option(ctx.query).map(plan) match { + case Some(_) if columns.nonEmpty => + operationNotAllowed( + "Schema may not be specified in a Replace Table As Select (RTAS) statement", + ctx) + + case Some(_) if partCols.nonEmpty => + // non-reference partition columns are not allowed because schema can't be specified + operationNotAllowed( + "Partition column types may not be specified in Replace Table As Select (RTAS)", + ctx) + + case Some(query) => + ReplaceTableAsSelectStatement(table, query, partitioning, bucketSpec, properties, + provider, options, location, comment, writeOptions = Map.empty, serdeInfo, + orCreate = orCreate) + + case _ => + // Note: table schema includes both the table columns list and the partition columns + // with data type. + val schema = StructType(columns ++ partCols) + ReplaceTableStatement(table, schema, partitioning, bucketSpec, properties, provider, + options, location, comment, serdeInfo, orCreate = orCreate) + } + } + + /** + * Parse new column info from ADD COLUMN into a QualifiedColType. + */ + override def visitQualifiedColTypeWithPosition( + ctx: QualifiedColTypeWithPositionContext): QualifiedColType = withOrigin(ctx) { + val name = typedVisit[Seq[String]](ctx.name) + QualifiedColType( + path = if (name.length > 1) Some(UnresolvedFieldName(name.init)) else None, + colName = name.last, + dataType = typedVisit[DataType](ctx.dataType), + nullable = ctx.NULL == null, + comment = Option(ctx.commentSpec()).map(visitCommentSpec), + position = Option(ctx.colPosition).map( pos => + UnresolvedFieldPosition(typedVisit[ColumnPosition](pos)))) + } + + /** + * Create a [[CacheTable]] or [[CacheTableAsSelect]]. + * + * For example: + * {{{ + * CACHE [LAZY] TABLE multi_part_name + * [OPTIONS tablePropertyList] [[AS] query] + * }}} + */ + override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) { + import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + val query = Option(ctx.query).map(plan) + val relation = createUnresolvedRelation(ctx.multipartIdentifier) + val tableName = relation.multipartIdentifier + if (query.isDefined && tableName.length > 1) { + val catalogAndNamespace = tableName.init + throw addCatalogInCacheTableAsSelectNotAllowedError( + catalogAndNamespace.quoted, ctx) + } + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) + val isLazy = ctx.LAZY != null + if (query.isDefined) { + CacheTableAsSelect(tableName.head, query.get, source(ctx.query()), isLazy, options) + } else { + CacheTable(relation, tableName, isLazy, options) + } + } + + /** + * Create or replace a view. This creates a [[CreateViewStatement]] + * + * For example: + * {{{ + * CREATE [OR REPLACE] [[GLOBAL] TEMPORARY] VIEW [IF NOT EXISTS] multi_part_name + * [(column_name [COMMENT column_comment], ...) ] + * create_view_clauses + * + * AS SELECT ...; + * + * create_view_clauses (order insensitive): + * [COMMENT view_comment] + * [TBLPROPERTIES (property_name = property_value, ...)] + * }}} + */ + override def visitCreateView(ctx: CreateViewContext): LogicalPlan = withOrigin(ctx) { + if (!ctx.identifierList.isEmpty) { + operationNotAllowed("CREATE VIEW ... 
PARTITIONED ON", ctx) + } + + checkDuplicateClauses(ctx.commentSpec(), "COMMENT", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED ON", ctx) + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + + val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl => + icl.identifierComment.asScala.map { ic => + ic.identifier.getText -> Option(ic.commentSpec()).map(visitCommentSpec) + } + } + + val properties = ctx.tablePropertyList.asScala.headOption.map(visitPropertyKeyValues) + .getOrElse(Map.empty) + if (ctx.TEMPORARY != null && !properties.isEmpty) { + operationNotAllowed("TBLPROPERTIES can't coexist with CREATE TEMPORARY VIEW", ctx) + } + + val viewType = if (ctx.TEMPORARY == null) { + PersistedView + } else if (ctx.GLOBAL != null) { + GlobalTempView + } else { + LocalTempView + } + CreateViewStatement( + visitMultipartIdentifier(ctx.multipartIdentifier), + userSpecifiedColumns, + visitCommentSpecList(ctx.commentSpec()), + properties, + Option(source(ctx.query)), + plan(ctx.query), + ctx.EXISTS != null, + ctx.REPLACE != null, + viewType) + } + + /** + * Alter the query of a view. This creates a [[AlterViewAs]] + * + * For example: + * {{{ + * ALTER VIEW multi_part_name AS SELECT ...; + * }}} + */ + override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) { + AlterViewAs( + createUnresolvedView(ctx.multipartIdentifier, "ALTER VIEW ... AS"), + originalText = source(ctx.query), + query = plan(ctx.query)) + } + + private def alterViewTypeMismatchHint: Option[String] = Some("Please use ALTER TABLE instead.") + + private def alterTableTypeMismatchHint: Option[String] = Some("Please use ALTER VIEW instead.") + + def invalidInsertIntoError(ctx: InsertIntoContext): Throwable = { + new ParseException("Invalid InsertIntoContext", ctx) + } + + def insertOverwriteDirectoryUnsupportedError(ctx: InsertIntoContext): Throwable = { + new ParseException("INSERT OVERWRITE DIRECTORY is not supported", ctx) + } + + def columnAliasInOperationNotAllowedError(op: String, ctx: TableAliasContext): Throwable = { + new ParseException(s"Columns aliases are not allowed in $op.", ctx.identifierList()) + } + + def emptySourceForMergeError(ctx: MergeIntoTableContext): Throwable = { + new ParseException("Empty source for merge: you should specify a source" + + " table/subquery in merge.", ctx.source) + } + + def unrecognizedMatchedActionError(ctx: MatchedClauseContext): Throwable = { + new ParseException(s"Unrecognized matched action: ${ctx.matchedAction().getText}", + ctx.matchedAction()) + } + + def insertedValueNumberNotMatchFieldNumberError(ctx: NotMatchedClauseContext): Throwable = { + new ParseException("The number of inserted values cannot match the fields.", + ctx.notMatchedAction()) + } + + def unrecognizedNotMatchedActionError(ctx: NotMatchedClauseContext): Throwable = { + new ParseException(s"Unrecognized not matched action: ${ctx.notMatchedAction().getText}", + ctx.notMatchedAction()) + } + + def mergeStatementWithoutWhenClauseError(ctx: MergeIntoTableContext): Throwable = { + new ParseException("There must be at least one WHEN clause in a MERGE statement", ctx) + } + + def nonLastMatchedClauseOmitConditionError(ctx: MergeIntoTableContext): Throwable = { + new ParseException("When there are more than one MATCHED clauses in a MERGE " + + "statement, only the last MATCHED clause can omit the condition.", ctx) + } + + def nonLastNotMatchedClauseOmitConditionError(ctx: MergeIntoTableContext): Throwable = { + new ParseException("When there 
are more than one NOT MATCHED clauses in a MERGE " + + "statement, only the last NOT MATCHED clause can omit the condition.", ctx) + } + + def emptyPartitionKeyError(key: String, ctx: PartitionSpecContext): Throwable = { + new ParseException(s"Found an empty partition key '$key'.", ctx) + } + + def combinationQueryResultClausesUnsupportedError(ctx: QueryOrganizationContext): Throwable = { + new ParseException( + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx) + } + + def distributeByUnsupportedError(ctx: QueryOrganizationContext): Throwable = { + new ParseException("DISTRIBUTE BY is not supported", ctx) + } + + def transformNotSupportQuantifierError(ctx: ParserRuleContext): Throwable = { + new ParseException("TRANSFORM does not support DISTINCT/ALL in inputs", ctx) + } + + def transformWithSerdeUnsupportedError(ctx: ParserRuleContext): Throwable = { + new ParseException("TRANSFORM with serde is only supported in hive mode", ctx) + } + + def lateralWithPivotInFromClauseNotAllowedError(ctx: FromClauseContext): Throwable = { + new ParseException("LATERAL cannot be used together with PIVOT in FROM clause", ctx) + } + + def lateralJoinWithNaturalJoinUnsupportedError(ctx: ParserRuleContext): Throwable = { + new ParseException("LATERAL join with NATURAL join is not supported", ctx) + } + + def lateralJoinWithUsingJoinUnsupportedError(ctx: ParserRuleContext): Throwable = { + new ParseException("LATERAL join with USING join is not supported", ctx) + } + + def unsupportedLateralJoinTypeError(ctx: ParserRuleContext, joinType: String): Throwable = { + new ParseException(s"Unsupported LATERAL join type $joinType", ctx) + } + + def invalidLateralJoinRelationError(ctx: RelationPrimaryContext): Throwable = { + new ParseException(s"LATERAL can only be used with subquery", ctx) + } + + def repetitiveWindowDefinitionError(name: String, ctx: WindowClauseContext): Throwable = { + new ParseException(s"The definition of window '$name' is repetitive", ctx) + } + + def invalidWindowReferenceError(name: String, ctx: WindowClauseContext): Throwable = { + new ParseException(s"Window reference '$name' is not a window specification", ctx) + } + + def cannotResolveWindowReferenceError(name: String, ctx: WindowClauseContext): Throwable = { + new ParseException(s"Cannot resolve window reference '$name'", ctx) + } + + def joinCriteriaUnimplementedError(join: JoinCriteriaContext, ctx: RelationContext): Throwable = { + new ParseException(s"Unimplemented joinCriteria: $join", ctx) + } + + def naturalCrossJoinUnsupportedError(ctx: RelationContext): Throwable = { + new ParseException("NATURAL CROSS JOIN is not supported", ctx) + } + + def emptyInputForTableSampleError(ctx: ParserRuleContext): Throwable = { + new ParseException("TABLESAMPLE does not accept empty inputs.", ctx) + } + + def tableSampleByBytesUnsupportedError(msg: String, ctx: SampleMethodContext): Throwable = { + new ParseException(s"TABLESAMPLE($msg) is not supported", ctx) + } + + def invalidByteLengthLiteralError(bytesStr: String, ctx: SampleByBytesContext): Throwable = { + new ParseException(s"$bytesStr is not a valid byte length literal, " + + "expected syntax: DIGIT+ ('B' | 'K' | 'M' | 'G')", ctx) + } + + def invalidEscapeStringError(ctx: PredicateContext): Throwable = { + new ParseException("Invalid escape string. 
Escape string must contain only one character.", ctx)
+  }
+
+  def trimOptionUnsupportedError(trimOption: Int, ctx: TrimContext): Throwable = {
+    new ParseException("Function trim doesn't support with " +
+      s"type $trimOption. Please use BOTH, LEADING or TRAILING as trim type", ctx)
+  }
+
+  def functionNameUnsupportedError(functionName: String, ctx: ParserRuleContext): Throwable = {
+    new ParseException(s"Unsupported function name '$functionName'", ctx)
+  }
+
+  def cannotParseValueTypeError(
+      valueType: String, value: String, ctx: TypeConstructorContext): Throwable = {
+    new ParseException(s"Cannot parse the $valueType value: $value", ctx)
+  }
+
+  def cannotParseIntervalValueError(value: String, ctx: TypeConstructorContext): Throwable = {
+    new ParseException(s"Cannot parse the INTERVAL value: $value", ctx)
+  }
+
+  def literalValueTypeUnsupportedError(
+      valueType: String, ctx: TypeConstructorContext): Throwable = {
+    new ParseException(s"Literals of type '$valueType' are currently not supported.", ctx)
+  }
+
+  def parsingValueTypeError(
+      e: IllegalArgumentException, valueType: String, ctx: TypeConstructorContext): Throwable = {
+    val message = Option(e.getMessage).getOrElse(s"Exception parsing $valueType")
+    new ParseException(message, ctx)
+  }
+
+  def invalidNumericLiteralRangeError(rawStrippedQualifier: String, minValue: BigDecimal,
+      maxValue: BigDecimal, typeName: String, ctx: NumberContext): Throwable = {
+    new ParseException(s"Numeric literal $rawStrippedQualifier does not " +
+      s"fit in range [$minValue, $maxValue] for type $typeName", ctx)
+  }
+
+  def moreThanOneFromToUnitInIntervalLiteralError(ctx: ParserRuleContext): Throwable = {
+    new ParseException("Can only have a single from-to unit in the interval literal syntax", ctx)
+  }
+
+  def invalidIntervalLiteralError(ctx: IntervalContext): Throwable = {
+    new ParseException("at least one time unit should be given for interval literal", ctx)
+  }
+
+  def invalidIntervalFormError(value: String, ctx: MultiUnitsIntervalContext): Throwable = {
+    new ParseException("Can only use numbers in the interval value part for" +
+      s" multiple unit value pairs interval form, but got invalid value: $value", ctx)
+  }
+
+  def invalidFromToUnitValueError(ctx: IntervalValueContext): Throwable = {
+    new ParseException("The value of from-to unit must be a string", ctx)
+  }
+
+  def fromToIntervalUnsupportedError(
+      from: String, to: String, ctx: ParserRuleContext): Throwable = {
+    new ParseException(s"Intervals FROM $from TO $to are not supported.", ctx)
+  }
+
+  def mixedIntervalUnitsError(literal: String, ctx: ParserRuleContext): Throwable = {
+    new ParseException(s"Cannot mix year-month and day-time fields: $literal", ctx)
+  }
+
+  def dataTypeUnsupportedError(dataType: String, ctx: PrimitiveDataTypeContext): Throwable = {
+    new ParseException(s"DataType $dataType is not supported.", ctx)
+  }
+
+  def partitionTransformNotExpectedError(
+      name: String, describe: String, ctx: ApplyTransformContext): Throwable = {
+    new ParseException(s"Expected a column reference for transform $name: $describe", ctx)
+  }
+
+  def tooManyArgumentsForTransformError(name: String, ctx: ApplyTransformContext): Throwable = {
+    new ParseException(s"Too many arguments for transform $name", ctx)
+  }
+
+  def notEnoughArgumentsForTransformError(name: String, ctx: ApplyTransformContext): Throwable = {
+    new ParseException(s"Not enough arguments for transform $name", ctx)
+  }
+
+  def invalidBucketsNumberError(describe: String, ctx: ApplyTransformContext): Throwable = {
+    new ParseException(s"Invalid number of buckets: $describe", ctx)
+  }
+
+  def invalidTransformArgumentError(ctx: TransformArgumentContext): Throwable = {
+    new ParseException("Invalid transform argument", ctx)
+  }
+
+  def cannotCleanReservedNamespacePropertyError(
+      property: String, ctx: ParserRuleContext, msg: String): Throwable = {
+    new ParseException(s"$property is a reserved namespace property, $msg.", ctx)
+  }
+
+  def cannotCleanReservedTablePropertyError(
+      property: String, ctx: ParserRuleContext, msg: String): Throwable = {
+    new ParseException(s"$property is a reserved table property, $msg.", ctx)
+  }
+
+  def duplicatedTablePathsFoundError(
+      pathOne: String, pathTwo: String, ctx: ParserRuleContext): Throwable = {
+    new ParseException(s"Duplicated table paths found: '$pathOne' and '$pathTwo'. LOCATION" +
+      s" and the case insensitive key 'path' in OPTIONS are all used to indicate the custom" +
+      s" table path, you can only specify one of them.", ctx)
+  }
+
+  def storedAsAndStoredByBothSpecifiedError(ctx: CreateFileFormatContext): Throwable = {
+    new ParseException("Expected either STORED AS or STORED BY, not both", ctx)
+  }
+
+  def operationInHiveStyleCommandUnsupportedError(operation: String,
+      command: String, ctx: StatementContext, msgOpt: Option[String] = None): Throwable = {
+    val basicError = s"$operation is not supported in Hive-style $command"
+    val msg = if (msgOpt.isDefined) {
+      s"$basicError, ${msgOpt.get}."
+    } else {
+      basicError
+    }
+    new ParseException(msg, ctx)
+  }
+
+  def operationNotAllowedError(message: String, ctx: ParserRuleContext): Throwable = {
+    new ParseException(s"Operation not allowed: $message", ctx)
+  }
+
+  def computeStatisticsNotExpectedError(ctx: IdentifierContext): Throwable = {
+    new ParseException(s"Expected `NOSCAN` instead of `${ctx.getText}`", ctx)
+  }
+
+  def addCatalogInCacheTableAsSelectNotAllowedError(
+      quoted: String, ctx: CacheTableContext): Throwable = {
+    new ParseException(s"It is not allowed to add catalog/namespace prefix $quoted to " +
+      "the table name in CACHE TABLE AS SELECT", ctx)
+  }
+
+  def showFunctionsUnsupportedError(identifier: String, ctx: IdentifierContext): Throwable = {
+    new ParseException(s"SHOW $identifier FUNCTIONS not supported", ctx)
+  }
+
+  def duplicateCteDefinitionNamesError(duplicateNames: String, ctx: CtesContext): Throwable = {
+    new ParseException(s"CTE definition can't have duplicate names: $duplicateNames.", ctx)
+  }
+
+  def sqlStatementUnsupportedError(sqlText: String, position: Origin): Throwable = {
+    new ParseException(Option(sqlText), "Unsupported SQL statement", position, position)
+  }
+
+  def unquotedIdentifierError(ident: String, ctx: ErrorIdentContext): Throwable = {
+    new ParseException(s"Possibly unquoted identifier $ident detected. " +
+      s"Please consider quoting it with back-quotes as `$ident`", ctx)
+  }
+
+  def duplicateClausesError(clauseName: String, ctx: ParserRuleContext): Throwable = {
+    new ParseException(s"Found duplicate clauses: $clauseName", ctx)
+  }
+
+  def duplicateKeysError(key: String, ctx: ParserRuleContext): Throwable = {
+    // Found duplicate keys '$key'
+    new ParseException(errorClass = "DUPLICATE_KEY", messageParameters = Array(key), ctx)
+  }
+
+  def intervalValueOutOfRangeError(ctx: IntervalContext): Throwable = {
+    new ParseException("The interval value must be in the range of [-18, +18] hours" +
+      " with second precision", ctx)
+  }
+
+  def createTempTableNotSpecifyProviderError(ctx: CreateTableContext): Throwable = {
+    new ParseException("CREATE TEMPORARY TABLE without a provider is not allowed.", ctx)
+  }
+
+  def useDefinedRecordReaderOrWriterClassesError(ctx: ParserRuleContext): Throwable = {
+    new ParseException(
+      "Unsupported operation: Used defined record reader/writer classes.", ctx)
+  }
+
+  def directoryPathAndOptionsPathBothSpecifiedError(ctx: InsertOverwriteDirContext): Throwable = {
+    new ParseException(
+      "Directory path and 'path' in OPTIONS should be specified one, but not both", ctx)
+  }
+
+  def unsupportedLocalFileSchemeError(ctx: InsertOverwriteDirContext): Throwable = {
+    new ParseException("LOCAL is supported only with file: scheme", ctx)
+  }
+
+  def invalidGroupingSetError(element: String, ctx: GroupingAnalyticsContext): Throwable = {
+    new ParseException(s"Empty set in $element grouping sets is not supported.", ctx)
+  }
+}
+
+/**
+ * A container for holding named common table expressions (CTEs) and a query plan.
+ * This operator will be removed during analysis and the relations will be substituted into child.
+ *
+ * @param child The final query of this CTE.
+ * @param cteRelations A sequence of pair (alias, the CTE definition) that this CTE defined
+ *                     Each CTE can see the base tables and the previously defined CTEs only.
+ */
+case class With(child: LogicalPlan, cteRelations: Seq[(String, SubqueryAlias)]) extends UnaryNode {
+  override def output: Seq[Attribute] = child.output
+
+  override def simpleString(maxFields: Int): String = {
+    val cteAliases = truncatedString(cteRelations.map(_._1), "[", ", ", "]", maxFields)
+    s"CTE $cteAliases"
+  }
+
+  override def innerChildren: Seq[LogicalPlan] = cteRelations.map(_._2)
+
+  def withNewChildInternal(newChild: LogicalPlan): LogicalPlan = this
+}
diff --git a/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlParser.scala b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlParser.scala
new file mode 100644
index 0000000000000..59ef8dfe0969b
--- /dev/null
+++ b/hudi-spark-datasource/hudi-spark3/src/main/scala/org/apache/spark/sql/parser/HoodieSpark3_2ExtendedSqlParser.scala
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.parser
+
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.atn.PredictionMode
+import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException}
+import org.antlr.v4.runtime.tree.TerminalNodeImpl
+import org.apache.hudi.spark.sql.parser.HoodieSqlBaseParser.{NonReservedContext, QuotedIdentifierContext}
+import org.apache.hudi.spark.sql.parser.{HoodieSqlBaseBaseListener, HoodieSqlBaseLexer, HoodieSqlBaseParser}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.parser.{ParseErrorListener, ParseException, ParserInterface}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{AnalysisException, SparkSession}
+
+class HoodieSpark3_2ExtendedSqlParser(session: SparkSession, delegate: ParserInterface)
+  extends ParserInterface with Logging {
+
+  private lazy val conf = session.sqlContext.conf
+  private lazy val builder = new HoodieSpark3_2ExtendedSqlAstBuilder(conf, delegate)
+
+  override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
+    builder.visit(parser.singleStatement()) match {
+      case plan: LogicalPlan => plan
+      case _ => delegate.parsePlan(sqlText)
+    }
+  }
+
+  override def parseExpression(sqlText: String): Expression = delegate.parseExpression(sqlText)
+
+  override def parseTableIdentifier(sqlText: String): TableIdentifier =
+    delegate.parseTableIdentifier(sqlText)
+
+  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
+    delegate.parseFunctionIdentifier(sqlText)
+
+  override def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText)
+
+  override def parseDataType(sqlText: String): DataType = delegate.parseDataType(sqlText)
+
+  protected def parse[T](command: String)(toResult: HoodieSqlBaseParser => T): T = {
+    logDebug(s"Parsing command: $command")
+
+    val lexer = new HoodieSqlBaseLexer(new UpperCaseCharStream(CharStreams.fromString(command)))
+    lexer.removeErrorListeners()
+    lexer.addErrorListener(ParseErrorListener)
+
+    val tokenStream = new CommonTokenStream(lexer)
+    val parser = new HoodieSqlBaseParser(tokenStream)
+    parser.addParseListener(PostProcessor)
+    parser.removeErrorListeners()
+    parser.addErrorListener(ParseErrorListener)
+//    parser.legacy_setops_precedence_enabled = conf.setOpsPrecedenceEnforced
+    parser.legacy_exponent_literal_as_decimal_enabled = conf.exponentLiteralAsDecimalEnabled
+    parser.SQL_standard_keyword_behavior = conf.ansiEnabled
+
+    try {
+      try {
+        // first, try parsing with potentially faster SLL mode
+        parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
+        toResult(parser)
+      }
+      catch {
+        case e: ParseCancellationException =>
+          // if we fail, parse with LL mode
+          tokenStream.seek(0) // rewind input stream
+          parser.reset()
+
+          // Try Again.
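+          // Two-stage ANTLR parsing, mirroring Spark's AbstractSqlParser: SLL prediction is
+          // faster but can raise a spurious ParseCancellationException on some inputs, so the
+          // statement is re-parsed once with the slower, fully general LL mode before giving up.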
+          parser.getInterpreter.setPredictionMode(PredictionMode.LL)
+          toResult(parser)
+      }
+    }
+    catch {
+      case e: ParseException if e.command.isDefined =>
+        throw e
+      case e: ParseException =>
+        throw e.withCommand(command)
+      case e: AnalysisException =>
+        val position = Origin(e.line, e.startPosition)
+        throw new ParseException(Option(command), e.message, position, position)
+    }
+  }
+
+  override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
+    delegate.parseMultipartIdentifier(sqlText)
+  }
+}
+
+/**
+ * Fork from `org.apache.spark.sql.catalyst.parser.UpperCaseCharStream`.
+ */
+class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream {
+  override def consume(): Unit = wrapped.consume
+  override def getSourceName(): String = wrapped.getSourceName
+  override def index(): Int = wrapped.index
+  override def mark(): Int = wrapped.mark
+  override def release(marker: Int): Unit = wrapped.release(marker)
+  override def seek(where: Int): Unit = wrapped.seek(where)
+  override def size(): Int = wrapped.size
+
+  override def getText(interval: Interval): String = {
+    // ANTLR 4.7's CodePointCharStream implementations have bugs when
+    // getText() is called with an empty stream, or intervals where
+    // the start > end. See
+    // https://github.com/antlr/antlr4/commit/ac9f7530 for one fix
+    // that is not yet in a released ANTLR artifact.
+    if (size() > 0 && (interval.b - interval.a >= 0)) {
+      wrapped.getText(interval)
+    } else {
+      ""
+    }
+  }
+  // scalastyle:off
+  override def LA(i: Int): Int = {
+    // scalastyle:on
+    val la = wrapped.LA(i)
+    if (la == 0 || la == IntStream.EOF) la
+    else Character.toUpperCase(la)
+  }
+}
+
+/**
+ * Fork from `org.apache.spark.sql.catalyst.parser.PostProcessor`.
+ */
+case object PostProcessor extends HoodieSqlBaseBaseListener {
+
+  /** Remove the back ticks from an Identifier. */
+  override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = {
+    replaceTokenByIdentifier(ctx, 1) { token =>
+      // Remove the double back ticks in the string.
+      token.setText(token.getText.replace("``", "`"))
+      token
+    }
+  }
+
+  /** Treat non-reserved keywords as Identifiers. */
+  override def exitNonReserved(ctx: NonReservedContext): Unit = {
+    replaceTokenByIdentifier(ctx, 0)(identity)
+  }
+
+  private def replaceTokenByIdentifier(
+      ctx: ParserRuleContext,
+      stripMargins: Int)(
+      f: CommonToken => CommonToken = identity): Unit = {
+    val parent = ctx.getParent
+    parent.removeLastChild()
+    val token = ctx.getChild(0).getPayload.asInstanceOf[Token]
+    val newToken = new CommonToken(
+      new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream),
+      HoodieSqlBaseParser.IDENTIFIER,
+      token.getChannel,
+      token.getStartIndex + stripMargins,
+      token.getStopIndex - stripMargins)
+    parent.addChild(new TerminalNodeImpl(f(newToken)))
+  }
+}
diff --git a/pom.xml b/pom.xml
index a0f7c2c6e303e..c277c3037b0d2 100644
--- a/pom.xml
+++ b/pom.xml
@@ -120,6 +120,8 @@
     1.14.3
     2.4.4
     3.2.1
+    3.1.2
+    3.2.1
     hudi-spark2
     hudi-spark2-common
     1.8.2
@@ -1584,6 +1586,7 @@
       spark3
+      ${spark3.2.version}
       ${spark3.version}
       ${spark3.version}
       ${scala12.version}
@@ -1612,7 +1615,7 @@
       spark3.1.x
-      3.1.2
+      ${spark3.1.version}
       ${spark3.version}
       ${spark3.version}
       ${scala12.version}
diff --git a/style/scalastyle.xml b/style/scalastyle.xml
index 74d7b9d73a203..a1b4cdbb6dafa 100644
--- a/style/scalastyle.xml
+++ b/style/scalastyle.xml
@@ -27,7 +27,7 @@
-
+
@@ -113,7 +113,7 @@
-
+
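The extended parser added above only takes effect once Hudi's session extension is injected into the SparkSession. A minimal usage sketch, assuming the standard Hudi entry point org.apache.spark.sql.hudi.HoodieSparkSessionExtension (which, on a Spark 3.2 runtime with this patch, should select HoodieSpark3_2ExtendedSqlParser and delegate everything it does not recognize to Spark's own parser):

import org.apache.spark.sql.SparkSession

// Build a session whose SQL parser chain includes the Hudi extended parser.
val spark = SparkSession.builder()
  .appName("hudi-extended-sql-parser-demo")
  .master("local[*]")
  .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .config("spark.sql.extensions", "org.apache.spark.sql.hudi.HoodieSparkSessionExtension")
  .getOrCreate()

// Plain Spark SQL still parses through the delegate; Hudi-specific statements
// go through the extended grammar introduced by this patch.
spark.sql("select 1").show()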